In [1]:
# Importing the libraries
from finetune import *
from torchsummary import summary
# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Load the held-out test split & build its dataloader
def load_data(path='Dataset_Specific_labelled.h5', batch_size=512,
              test_size=0.2, random_state=42, shuffle=False):
    '''
    Load only the test split of the labelled dataset and wrap it in a DataLoader.

    Different from the one in finetune.py, this function returns test data only.
    The split is reproduced with train_test_split using the given test_size and
    random_state (defaults 0.2 / 42). NOTE(review): these must match the split
    used during fine-tuning, otherwise "test" samples may overlap training data.

    Parameters
    ----------
    path : str
        HDF5 file holding a 'jet' dataset (stored as b, h, w, c) and a 'Y'
        label dataset. The whole file is read into memory before splitting.
    batch_size : int
        Batch size of the returned DataLoader.
    test_size : float
        Fraction of samples held out as the test split (forwarded to
        train_test_split).
    random_state : int
        Seed forwarded to train_test_split so the split is reproducible.
    shuffle : bool
        Whether the loader shuffles. Defaults to False: evaluation metrics
        do not depend on sample order, and a fixed order keeps runs reproducible.

    Returns
    -------
    DataLoader
        Yields (X, y) float32 tensors. X is (b, c, h, w) and y is a
        two-column one-hot target [1 - label, label].
    '''
    # The context manager closes the HDF5 handle as soon as the arrays are in memory.
    with h5py.File(path, 'r') as data:
        X = data['jet'][...]
        y = np.squeeze(data['Y'][...])
    # Stored channels-last; PyTorch conv layers expect channels-first.
    X = rearrange(X, 'b h w c -> b c h w')
    # Binary labels -> one-hot targets: column 0 is class 0, column 1 is class 1.
    y = np.stack([1 - y, y], axis=1)
    _, X_test, _, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    testset = TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                            torch.tensor(y_test, dtype=torch.float32))
    return DataLoader(testset, batch_size=batch_size, shuffle=shuffle)

In [3]:
# Load the fine-tuned model and its trained weights onto `device`
model = FineTunedModel()
# map_location remaps stored tensors onto `device`, so weights saved on a GPU
# still load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; consider weights_only=True
# (PyTorch >= 1.13) if 'weights.pth' may come from an untrusted source.
state_dict = torch.load('weights.pth', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Inference mode: switches layers such as Dropout to their evaluation behaviour.
model.eval()

FineTunedModel(
  (model): ViTMAE(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(8, 256, kernel_size=(5, 5), stride=(5, 5))
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (fc2): Linear(in_fe

In [4]:
# Loss function handed to test_model, and the dataloader over the held-out split
criterion = nn.CrossEntropyLoss()
testloader = load_data()

In [5]:
# Evaluate on the held-out test set and report accuracy as a percentage
acc, _ = test_model(testloader, model, criterion, device)
# `acc` is a fraction (the old code printed acc*100 as a percent). A fixed-precision
# percent format avoids float artefacts such as 28.999999999999996 in the report.
print(f"Accuracy: {float(acc):.2%}")

Accuracy:  78.45 %


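## Passing the test cases
Use the following code to pass the test cases: feed a batch of inputs through the model and take the arg-max over the class dimension to get one predicted label per input.
```python
var = ...          # float32 tensor of shape (batch, 8, 125, 125), on `device`
out = model(var)   # one row of class scores per input
out.argmax(1)      # index of the highest-scoring class for each input
```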

In [6]:
# Layer-by-layer summary for a single (8, 125, 125) input. The device is passed
# explicitly so the dummy input is built on the same device as `model` (some
# torchsummary versions otherwise default to CUDA even on CPU-only machines).
summary(model, (8, 125, 125), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 25, 25]          51,456
          Identity-2             [-1, 625, 256]               0
        PatchEmbed-3             [-1, 625, 256]               0
         LayerNorm-4             [-1, 626, 256]             512
            Linear-5             [-1, 626, 768]         197,376
          Identity-6           [-1, 4, 626, 64]               0
          Identity-7           [-1, 4, 626, 64]               0
            Linear-8             [-1, 626, 256]          65,792
           Dropout-9             [-1, 626, 256]               0
        Attention-10             [-1, 626, 256]               0
         Identity-11             [-1, 626, 256]               0
         Identity-12             [-1, 626, 256]               0
        LayerNorm-13             [-1, 626, 256]             512
           Linear-14            [-1, 62