In [1]:
import random
import sys

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

# Make packages in the current working directory (the local `cgiar` package) importable;
# assumes the notebook is launched from the repository root.
sys.path.append('.')

from cgiar.utils import get_dir
from cgiar.model import XCITMultipleMLP
from cgiar.data import CGIARDataset_V4, augmentations

# Shrink the default font size used by every matplotlib figure in this notebook.
plt.rcParams['font.size'] = 8

In [2]:
SEED = 42
# Training-time settings; not referenced by the inference cells in this notebook.
LR = 1e-4
EPOCHS = 30
TRAIN_BATCH_SIZE = 64
NUM_FOLDS = 5
# Inference settings.
IMAGE_SIZE = 224        # side of the crop fed to the model
INITIAL_SIZE = 512      # size passed to CGIARDataset_V4.load_images
TEST_BATCH_SIZE = 32
HIDDEN_SIZE = 512       # must match the value used when the checkpoint was trained
NUM_VIEWS = 10          # augmented views per test image, averaged at inference

# The test-time transform is random (RandomResizedCrop plus the Random* augmentations),
# so seed every RNG it may draw from; otherwise the averaged TTA predictions differ
# from run to run. torch.manual_seed also seeds all CUDA devices.
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Run on CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Input data directory, and the directory holding this solution's per-fold checkpoints
# (`model_fold_<k>.pth`).
DATA_DIR, ARTIFACTS = get_dir('data/'), get_dir('solutions/v10/#1/')

In [5]:
# Test-set metadata: used to locate the test images and as the dataset's `features`.
test_csv_path = DATA_DIR / 'Test.csv'
X_test = pd.read_csv(test_csv_path)

In [6]:
# Load every test image once, up front (INITIAL_SIZE is passed to the loader), and turn the
# loader's (key, image) pairs into a dict. The key is presumably the image ID; verify in
# cgiar.data. A separate intermediate name keeps this cell safe to re-run.
loaded_test_images = CGIARDataset_V4.load_images(X_test, DATA_DIR / "test", INITIAL_SIZE)
test_images = dict(loaded_test_images[i] for i in range(len(loaded_test_images)))
del loaded_test_images

100%|██████████| 8663/8663 [03:01<00:00, 47.73it/s] 


In [7]:
# Standard ImageNet per-channel statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Test-time augmentation pipeline: each of the NUM_VIEWS views is a different random crop
# and augmentation of the same image, and the model outputs are averaged afterwards.
# NOTE(review): torchvision's RandomResizedCrop defaults to scale=(0.08, 1.0), which can keep
# very small regions at test time. Also, if augmentations["RandomErasing"] wraps torchvision's
# tensor-only RandomErasing, it would have to run after ToTensor. Confirm both are intended.
tta_steps = [transforms.RandomResizedCrop(IMAGE_SIZE)]
tta_steps += [
    augmentations[key]
    for key in ("RandomEqualize", "RandomBlur", "RandomErasing", "RandomAffine")
]
tta_steps += [
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(tta_steps)

In [8]:
fold_idx: int = 0  # which fold's checkpoint to evaluate (presumably 0 .. NUM_FOLDS - 1)

In [9]:
# XCiT-nano backbone wrapped by the project's multi-MLP model. Every weight is overwritten by
# the fold checkpoint loaded in the next cell, so `pretrained=True` presumably only costs a
# download. TODO: consider pretrained=False if XCITMultipleMLP builds the same modules either way.
model = XCITMultipleMLP(
    hidden_size=HIDDEN_SIZE,
    num_mlps=4,
    model_name="xcit_nano_12_p16_224",
    pretrained=True,
)

In [10]:
# Restore the weights trained on fold `fold_idx`. `map_location` remaps the stored tensors
# onto `device`, so a checkpoint saved on GPU also loads on a CPU-only machine.
checkpoint_path = ARTIFACTS / f"model_fold_{fold_idx}.pth"
state_dict = torch.load(checkpoint_path, map_location=device)

# NOTE(review): with the current cgiar.model.XCITMultipleMLP this raises a RuntimeError.
# The checkpoint stores the backbone under `model.0.*` plus extra layers `model.1.*` /
# `model.3.*`, while the class now expects `model.*` with a `model.head.*` sub-module. It was
# apparently saved from a different version of the class. TODO: restore the matching
# architecture. Do NOT pass strict=False: the backbone would silently keep its pretrained
# weights instead of the trained ones.
model.load_state_dict(state_dict)

# The inference inputs are moved to `device`, so the model must live there too. eval() makes
# the batch-norm layers (and any dropout) behave in inference mode.
model.to(device)
model.eval();

RuntimeError: Error(s) in loading state_dict for XCITMultipleMLP:
	Missing key(s) in state_dict: "model.cls_token", "model.patch_embed.proj.0.0.weight", "model.patch_embed.proj.0.1.weight", "model.patch_embed.proj.0.1.bias", "model.patch_embed.proj.0.1.running_mean", "model.patch_embed.proj.0.1.running_var", "model.patch_embed.proj.2.0.weight", "model.patch_embed.proj.2.1.weight", "model.patch_embed.proj.2.1.bias", "model.patch_embed.proj.2.1.running_mean", "model.patch_embed.proj.2.1.running_var", "model.patch_embed.proj.4.0.weight", "model.patch_embed.proj.4.1.weight", "model.patch_embed.proj.4.1.bias", "model.patch_embed.proj.4.1.running_mean", "model.patch_embed.proj.4.1.running_var", "model.patch_embed.proj.6.0.weight", "model.patch_embed.proj.6.1.weight", "model.patch_embed.proj.6.1.bias", "model.patch_embed.proj.6.1.running_mean", "model.patch_embed.proj.6.1.running_var", "model.pos_embed.token_projection.weight", "model.pos_embed.token_projection.bias", "model.blocks.0.gamma1", "model.blocks.0.gamma3", "model.blocks.0.gamma2", "model.blocks.0.norm1.weight", "model.blocks.0.norm1.bias", "model.blocks.0.attn.temperature", "model.blocks.0.attn.qkv.weight", "model.blocks.0.attn.qkv.bias", "model.blocks.0.attn.proj.weight", "model.blocks.0.attn.proj.bias", "model.blocks.0.norm3.weight", "model.blocks.0.norm3.bias", "model.blocks.0.local_mp.conv1.weight", "model.blocks.0.local_mp.conv1.bias", "model.blocks.0.local_mp.bn.weight", "model.blocks.0.local_mp.bn.bias", "model.blocks.0.local_mp.bn.running_mean", "model.blocks.0.local_mp.bn.running_var", "model.blocks.0.local_mp.conv2.weight", "model.blocks.0.local_mp.conv2.bias", "model.blocks.0.norm2.weight", "model.blocks.0.norm2.bias", "model.blocks.0.mlp.fc1.weight", "model.blocks.0.mlp.fc1.bias", "model.blocks.0.mlp.fc2.weight", "model.blocks.0.mlp.fc2.bias", "model.blocks.1.gamma1", "model.blocks.1.gamma3", "model.blocks.1.gamma2", "model.blocks.1.norm1.weight", "model.blocks.1.norm1.bias", "model.blocks.1.attn.temperature", "model.blocks.1.attn.qkv.weight", 
"model.blocks.1.attn.qkv.bias", "model.blocks.1.attn.proj.weight", "model.blocks.1.attn.proj.bias", "model.blocks.1.norm3.weight", "model.blocks.1.norm3.bias", "model.blocks.1.local_mp.conv1.weight", "model.blocks.1.local_mp.conv1.bias", "model.blocks.1.local_mp.bn.weight", "model.blocks.1.local_mp.bn.bias", "model.blocks.1.local_mp.bn.running_mean", "model.blocks.1.local_mp.bn.running_var", "model.blocks.1.local_mp.conv2.weight", "model.blocks.1.local_mp.conv2.bias", "model.blocks.1.norm2.weight", "model.blocks.1.norm2.bias", "model.blocks.1.mlp.fc1.weight", "model.blocks.1.mlp.fc1.bias", "model.blocks.1.mlp.fc2.weight", "model.blocks.1.mlp.fc2.bias", "model.blocks.2.gamma1", "model.blocks.2.gamma3", "model.blocks.2.gamma2", "model.blocks.2.norm1.weight", "model.blocks.2.norm1.bias", "model.blocks.2.attn.temperature", "model.blocks.2.attn.qkv.weight", "model.blocks.2.attn.qkv.bias", "model.blocks.2.attn.proj.weight", "model.blocks.2.attn.proj.bias", "model.blocks.2.norm3.weight", "model.blocks.2.norm3.bias", "model.blocks.2.local_mp.conv1.weight", "model.blocks.2.local_mp.conv1.bias", "model.blocks.2.local_mp.bn.weight", "model.blocks.2.local_mp.bn.bias", "model.blocks.2.local_mp.bn.running_mean", "model.blocks.2.local_mp.bn.running_var", "model.blocks.2.local_mp.conv2.weight", "model.blocks.2.local_mp.conv2.bias", "model.blocks.2.norm2.weight", "model.blocks.2.norm2.bias", "model.blocks.2.mlp.fc1.weight", "model.blocks.2.mlp.fc1.bias", "model.blocks.2.mlp.fc2.weight", "model.blocks.2.mlp.fc2.bias", "model.blocks.3.gamma1", "model.blocks.3.gamma3", "model.blocks.3.gamma2", "model.blocks.3.norm1.weight", "model.blocks.3.norm1.bias", "model.blocks.3.attn.temperature", "model.blocks.3.attn.qkv.weight", "model.blocks.3.attn.qkv.bias", "model.blocks.3.attn.proj.weight", "model.blocks.3.attn.proj.bias", "model.blocks.3.norm3.weight", "model.blocks.3.norm3.bias", "model.blocks.3.local_mp.conv1.weight", "model.blocks.3.local_mp.conv1.bias", 
"model.blocks.3.local_mp.bn.weight", "model.blocks.3.local_mp.bn.bias", "model.blocks.3.local_mp.bn.running_mean", "model.blocks.3.local_mp.bn.running_var", "model.blocks.3.local_mp.conv2.weight", "model.blocks.3.local_mp.conv2.bias", "model.blocks.3.norm2.weight", "model.blocks.3.norm2.bias", "model.blocks.3.mlp.fc1.weight", "model.blocks.3.mlp.fc1.bias", "model.blocks.3.mlp.fc2.weight", "model.blocks.3.mlp.fc2.bias", "model.blocks.4.gamma1", "model.blocks.4.gamma3", "model.blocks.4.gamma2", "model.blocks.4.norm1.weight", "model.blocks.4.norm1.bias", "model.blocks.4.attn.temperature", "model.blocks.4.attn.qkv.weight", "model.blocks.4.attn.qkv.bias", "model.blocks.4.attn.proj.weight", "model.blocks.4.attn.proj.bias", "model.blocks.4.norm3.weight", "model.blocks.4.norm3.bias", "model.blocks.4.local_mp.conv1.weight", "model.blocks.4.local_mp.conv1.bias", "model.blocks.4.local_mp.bn.weight", "model.blocks.4.local_mp.bn.bias", "model.blocks.4.local_mp.bn.running_mean", "model.blocks.4.local_mp.bn.running_var", "model.blocks.4.local_mp.conv2.weight", "model.blocks.4.local_mp.conv2.bias", "model.blocks.4.norm2.weight", "model.blocks.4.norm2.bias", "model.blocks.4.mlp.fc1.weight", "model.blocks.4.mlp.fc1.bias", "model.blocks.4.mlp.fc2.weight", "model.blocks.4.mlp.fc2.bias", "model.blocks.5.gamma1", "model.blocks.5.gamma3", "model.blocks.5.gamma2", "model.blocks.5.norm1.weight", "model.blocks.5.norm1.bias", "model.blocks.5.attn.temperature", "model.blocks.5.attn.qkv.weight", "model.blocks.5.attn.qkv.bias", "model.blocks.5.attn.proj.weight", "model.blocks.5.attn.proj.bias", "model.blocks.5.norm3.weight", "model.blocks.5.norm3.bias", "model.blocks.5.local_mp.conv1.weight", "model.blocks.5.local_mp.conv1.bias", "model.blocks.5.local_mp.bn.weight", "model.blocks.5.local_mp.bn.bias", "model.blocks.5.local_mp.bn.running_mean", "model.blocks.5.local_mp.bn.running_var", "model.blocks.5.local_mp.conv2.weight", "model.blocks.5.local_mp.conv2.bias", "model.blocks.5.norm2.weight", 
"model.blocks.5.norm2.bias", "model.blocks.5.mlp.fc1.weight", "model.blocks.5.mlp.fc1.bias", "model.blocks.5.mlp.fc2.weight", "model.blocks.5.mlp.fc2.bias", "model.blocks.6.gamma1", "model.blocks.6.gamma3", "model.blocks.6.gamma2", "model.blocks.6.norm1.weight", "model.blocks.6.norm1.bias", "model.blocks.6.attn.temperature", "model.blocks.6.attn.qkv.weight", "model.blocks.6.attn.qkv.bias", "model.blocks.6.attn.proj.weight", "model.blocks.6.attn.proj.bias", "model.blocks.6.norm3.weight", "model.blocks.6.norm3.bias", "model.blocks.6.local_mp.conv1.weight", "model.blocks.6.local_mp.conv1.bias", "model.blocks.6.local_mp.bn.weight", "model.blocks.6.local_mp.bn.bias", "model.blocks.6.local_mp.bn.running_mean", "model.blocks.6.local_mp.bn.running_var", "model.blocks.6.local_mp.conv2.weight", "model.blocks.6.local_mp.conv2.bias", "model.blocks.6.norm2.weight", "model.blocks.6.norm2.bias", "model.blocks.6.mlp.fc1.weight", "model.blocks.6.mlp.fc1.bias", "model.blocks.6.mlp.fc2.weight", "model.blocks.6.mlp.fc2.bias", "model.blocks.7.gamma1", "model.blocks.7.gamma3", "model.blocks.7.gamma2", "model.blocks.7.norm1.weight", "model.blocks.7.norm1.bias", "model.blocks.7.attn.temperature", "model.blocks.7.attn.qkv.weight", "model.blocks.7.attn.qkv.bias", "model.blocks.7.attn.proj.weight", "model.blocks.7.attn.proj.bias", "model.blocks.7.norm3.weight", "model.blocks.7.norm3.bias", "model.blocks.7.local_mp.conv1.weight", "model.blocks.7.local_mp.conv1.bias", "model.blocks.7.local_mp.bn.weight", "model.blocks.7.local_mp.bn.bias", "model.blocks.7.local_mp.bn.running_mean", "model.blocks.7.local_mp.bn.running_var", "model.blocks.7.local_mp.conv2.weight", "model.blocks.7.local_mp.conv2.bias", "model.blocks.7.norm2.weight", "model.blocks.7.norm2.bias", "model.blocks.7.mlp.fc1.weight", "model.blocks.7.mlp.fc1.bias", "model.blocks.7.mlp.fc2.weight", "model.blocks.7.mlp.fc2.bias", "model.blocks.8.gamma1", "model.blocks.8.gamma3", "model.blocks.8.gamma2", "model.blocks.8.norm1.weight", 
"model.blocks.8.norm1.bias", "model.blocks.8.attn.temperature", "model.blocks.8.attn.qkv.weight", "model.blocks.8.attn.qkv.bias", "model.blocks.8.attn.proj.weight", "model.blocks.8.attn.proj.bias", "model.blocks.8.norm3.weight", "model.blocks.8.norm3.bias", "model.blocks.8.local_mp.conv1.weight", "model.blocks.8.local_mp.conv1.bias", "model.blocks.8.local_mp.bn.weight", "model.blocks.8.local_mp.bn.bias", "model.blocks.8.local_mp.bn.running_mean", "model.blocks.8.local_mp.bn.running_var", "model.blocks.8.local_mp.conv2.weight", "model.blocks.8.local_mp.conv2.bias", "model.blocks.8.norm2.weight", "model.blocks.8.norm2.bias", "model.blocks.8.mlp.fc1.weight", "model.blocks.8.mlp.fc1.bias", "model.blocks.8.mlp.fc2.weight", "model.blocks.8.mlp.fc2.bias", "model.blocks.9.gamma1", "model.blocks.9.gamma3", "model.blocks.9.gamma2", "model.blocks.9.norm1.weight", "model.blocks.9.norm1.bias", "model.blocks.9.attn.temperature", "model.blocks.9.attn.qkv.weight", "model.blocks.9.attn.qkv.bias", "model.blocks.9.attn.proj.weight", "model.blocks.9.attn.proj.bias", "model.blocks.9.norm3.weight", "model.blocks.9.norm3.bias", "model.blocks.9.local_mp.conv1.weight", "model.blocks.9.local_mp.conv1.bias", "model.blocks.9.local_mp.bn.weight", "model.blocks.9.local_mp.bn.bias", "model.blocks.9.local_mp.bn.running_mean", "model.blocks.9.local_mp.bn.running_var", "model.blocks.9.local_mp.conv2.weight", "model.blocks.9.local_mp.conv2.bias", "model.blocks.9.norm2.weight", "model.blocks.9.norm2.bias", "model.blocks.9.mlp.fc1.weight", "model.blocks.9.mlp.fc1.bias", "model.blocks.9.mlp.fc2.weight", "model.blocks.9.mlp.fc2.bias", "model.blocks.10.gamma1", "model.blocks.10.gamma3", "model.blocks.10.gamma2", "model.blocks.10.norm1.weight", "model.blocks.10.norm1.bias", "model.blocks.10.attn.temperature", "model.blocks.10.attn.qkv.weight", "model.blocks.10.attn.qkv.bias", "model.blocks.10.attn.proj.weight", "model.blocks.10.attn.proj.bias", "model.blocks.10.norm3.weight", "model.blocks.10.norm3.bias", 
"model.blocks.10.local_mp.conv1.weight", "model.blocks.10.local_mp.conv1.bias", "model.blocks.10.local_mp.bn.weight", "model.blocks.10.local_mp.bn.bias", "model.blocks.10.local_mp.bn.running_mean", "model.blocks.10.local_mp.bn.running_var", "model.blocks.10.local_mp.conv2.weight", "model.blocks.10.local_mp.conv2.bias", "model.blocks.10.norm2.weight", "model.blocks.10.norm2.bias", "model.blocks.10.mlp.fc1.weight", "model.blocks.10.mlp.fc1.bias", "model.blocks.10.mlp.fc2.weight", "model.blocks.10.mlp.fc2.bias", "model.blocks.11.gamma1", "model.blocks.11.gamma3", "model.blocks.11.gamma2", "model.blocks.11.norm1.weight", "model.blocks.11.norm1.bias", "model.blocks.11.attn.temperature", "model.blocks.11.attn.qkv.weight", "model.blocks.11.attn.qkv.bias", "model.blocks.11.attn.proj.weight", "model.blocks.11.attn.proj.bias", "model.blocks.11.norm3.weight", "model.blocks.11.norm3.bias", "model.blocks.11.local_mp.conv1.weight", "model.blocks.11.local_mp.conv1.bias", "model.blocks.11.local_mp.bn.weight", "model.blocks.11.local_mp.bn.bias", "model.blocks.11.local_mp.bn.running_mean", "model.blocks.11.local_mp.bn.running_var", "model.blocks.11.local_mp.conv2.weight", "model.blocks.11.local_mp.conv2.bias", "model.blocks.11.norm2.weight", "model.blocks.11.norm2.bias", "model.blocks.11.mlp.fc1.weight", "model.blocks.11.mlp.fc1.bias", "model.blocks.11.mlp.fc2.weight", "model.blocks.11.mlp.fc2.bias", "model.cls_attn_blocks.0.gamma1", "model.cls_attn_blocks.0.gamma2", "model.cls_attn_blocks.0.norm1.weight", "model.cls_attn_blocks.0.norm1.bias", "model.cls_attn_blocks.0.attn.q.weight", "model.cls_attn_blocks.0.attn.q.bias", "model.cls_attn_blocks.0.attn.k.weight", "model.cls_attn_blocks.0.attn.k.bias", "model.cls_attn_blocks.0.attn.v.weight", "model.cls_attn_blocks.0.attn.v.bias", "model.cls_attn_blocks.0.attn.proj.weight", "model.cls_attn_blocks.0.attn.proj.bias", "model.cls_attn_blocks.0.norm2.weight", "model.cls_attn_blocks.0.norm2.bias", "model.cls_attn_blocks.0.mlp.fc1.weight", 
"model.cls_attn_blocks.0.mlp.fc1.bias", "model.cls_attn_blocks.0.mlp.fc2.weight", "model.cls_attn_blocks.0.mlp.fc2.bias", "model.cls_attn_blocks.1.gamma1", "model.cls_attn_blocks.1.gamma2", "model.cls_attn_blocks.1.norm1.weight", "model.cls_attn_blocks.1.norm1.bias", "model.cls_attn_blocks.1.attn.q.weight", "model.cls_attn_blocks.1.attn.q.bias", "model.cls_attn_blocks.1.attn.k.weight", "model.cls_attn_blocks.1.attn.k.bias", "model.cls_attn_blocks.1.attn.v.weight", "model.cls_attn_blocks.1.attn.v.bias", "model.cls_attn_blocks.1.attn.proj.weight", "model.cls_attn_blocks.1.attn.proj.bias", "model.cls_attn_blocks.1.norm2.weight", "model.cls_attn_blocks.1.norm2.bias", "model.cls_attn_blocks.1.mlp.fc1.weight", "model.cls_attn_blocks.1.mlp.fc1.bias", "model.cls_attn_blocks.1.mlp.fc2.weight", "model.cls_attn_blocks.1.mlp.fc2.bias", "model.norm.weight", "model.norm.bias", "model.head.0.weight", "model.head.0.bias", "model.head.2.weight", "model.head.2.bias". 
	Unexpected key(s) in state_dict: "model.0.cls_token", "model.0.patch_embed.proj.0.0.weight", "model.0.patch_embed.proj.0.1.weight", "model.0.patch_embed.proj.0.1.bias", "model.0.patch_embed.proj.0.1.running_mean", "model.0.patch_embed.proj.0.1.running_var", "model.0.patch_embed.proj.0.1.num_batches_tracked", "model.0.patch_embed.proj.2.0.weight", "model.0.patch_embed.proj.2.1.weight", "model.0.patch_embed.proj.2.1.bias", "model.0.patch_embed.proj.2.1.running_mean", "model.0.patch_embed.proj.2.1.running_var", "model.0.patch_embed.proj.2.1.num_batches_tracked", "model.0.patch_embed.proj.4.0.weight", "model.0.patch_embed.proj.4.1.weight", "model.0.patch_embed.proj.4.1.bias", "model.0.patch_embed.proj.4.1.running_mean", "model.0.patch_embed.proj.4.1.running_var", "model.0.patch_embed.proj.4.1.num_batches_tracked", "model.0.patch_embed.proj.6.0.weight", "model.0.patch_embed.proj.6.1.weight", "model.0.patch_embed.proj.6.1.bias", "model.0.patch_embed.proj.6.1.running_mean", "model.0.patch_embed.proj.6.1.running_var", "model.0.patch_embed.proj.6.1.num_batches_tracked", "model.0.pos_embed.token_projection.weight", "model.0.pos_embed.token_projection.bias", "model.0.blocks.0.gamma1", "model.0.blocks.0.gamma3", "model.0.blocks.0.gamma2", "model.0.blocks.0.norm1.weight", "model.0.blocks.0.norm1.bias", "model.0.blocks.0.attn.temperature", "model.0.blocks.0.attn.qkv.weight", "model.0.blocks.0.attn.qkv.bias", "model.0.blocks.0.attn.proj.weight", "model.0.blocks.0.attn.proj.bias", "model.0.blocks.0.norm3.weight", "model.0.blocks.0.norm3.bias", "model.0.blocks.0.local_mp.conv1.weight", "model.0.blocks.0.local_mp.conv1.bias", "model.0.blocks.0.local_mp.bn.weight", "model.0.blocks.0.local_mp.bn.bias", "model.0.blocks.0.local_mp.bn.running_mean", "model.0.blocks.0.local_mp.bn.running_var", "model.0.blocks.0.local_mp.bn.num_batches_tracked", "model.0.blocks.0.local_mp.conv2.weight", "model.0.blocks.0.local_mp.conv2.bias", "model.0.blocks.0.norm2.weight", "model.0.blocks.0.norm2.bias", 
"model.0.blocks.0.mlp.fc1.weight", "model.0.blocks.0.mlp.fc1.bias", "model.0.blocks.0.mlp.fc2.weight", "model.0.blocks.0.mlp.fc2.bias", "model.0.blocks.1.gamma1", "model.0.blocks.1.gamma3", "model.0.blocks.1.gamma2", "model.0.blocks.1.norm1.weight", "model.0.blocks.1.norm1.bias", "model.0.blocks.1.attn.temperature", "model.0.blocks.1.attn.qkv.weight", "model.0.blocks.1.attn.qkv.bias", "model.0.blocks.1.attn.proj.weight", "model.0.blocks.1.attn.proj.bias", "model.0.blocks.1.norm3.weight", "model.0.blocks.1.norm3.bias", "model.0.blocks.1.local_mp.conv1.weight", "model.0.blocks.1.local_mp.conv1.bias", "model.0.blocks.1.local_mp.bn.weight", "model.0.blocks.1.local_mp.bn.bias", "model.0.blocks.1.local_mp.bn.running_mean", "model.0.blocks.1.local_mp.bn.running_var", "model.0.blocks.1.local_mp.bn.num_batches_tracked", "model.0.blocks.1.local_mp.conv2.weight", "model.0.blocks.1.local_mp.conv2.bias", "model.0.blocks.1.norm2.weight", "model.0.blocks.1.norm2.bias", "model.0.blocks.1.mlp.fc1.weight", "model.0.blocks.1.mlp.fc1.bias", "model.0.blocks.1.mlp.fc2.weight", "model.0.blocks.1.mlp.fc2.bias", "model.0.blocks.2.gamma1", "model.0.blocks.2.gamma3", "model.0.blocks.2.gamma2", "model.0.blocks.2.norm1.weight", "model.0.blocks.2.norm1.bias", "model.0.blocks.2.attn.temperature", "model.0.blocks.2.attn.qkv.weight", "model.0.blocks.2.attn.qkv.bias", "model.0.blocks.2.attn.proj.weight", "model.0.blocks.2.attn.proj.bias", "model.0.blocks.2.norm3.weight", "model.0.blocks.2.norm3.bias", "model.0.blocks.2.local_mp.conv1.weight", "model.0.blocks.2.local_mp.conv1.bias", "model.0.blocks.2.local_mp.bn.weight", "model.0.blocks.2.local_mp.bn.bias", "model.0.blocks.2.local_mp.bn.running_mean", "model.0.blocks.2.local_mp.bn.running_var", "model.0.blocks.2.local_mp.bn.num_batches_tracked", "model.0.blocks.2.local_mp.conv2.weight", "model.0.blocks.2.local_mp.conv2.bias", "model.0.blocks.2.norm2.weight", "model.0.blocks.2.norm2.bias", "model.0.blocks.2.mlp.fc1.weight", 
"model.0.blocks.2.mlp.fc1.bias", "model.0.blocks.2.mlp.fc2.weight", "model.0.blocks.2.mlp.fc2.bias", "model.0.blocks.3.gamma1", "model.0.blocks.3.gamma3", "model.0.blocks.3.gamma2", "model.0.blocks.3.norm1.weight", "model.0.blocks.3.norm1.bias", "model.0.blocks.3.attn.temperature", "model.0.blocks.3.attn.qkv.weight", "model.0.blocks.3.attn.qkv.bias", "model.0.blocks.3.attn.proj.weight", "model.0.blocks.3.attn.proj.bias", "model.0.blocks.3.norm3.weight", "model.0.blocks.3.norm3.bias", "model.0.blocks.3.local_mp.conv1.weight", "model.0.blocks.3.local_mp.conv1.bias", "model.0.blocks.3.local_mp.bn.weight", "model.0.blocks.3.local_mp.bn.bias", "model.0.blocks.3.local_mp.bn.running_mean", "model.0.blocks.3.local_mp.bn.running_var", "model.0.blocks.3.local_mp.bn.num_batches_tracked", "model.0.blocks.3.local_mp.conv2.weight", "model.0.blocks.3.local_mp.conv2.bias", "model.0.blocks.3.norm2.weight", "model.0.blocks.3.norm2.bias", "model.0.blocks.3.mlp.fc1.weight", "model.0.blocks.3.mlp.fc1.bias", "model.0.blocks.3.mlp.fc2.weight", "model.0.blocks.3.mlp.fc2.bias", "model.0.blocks.4.gamma1", "model.0.blocks.4.gamma3", "model.0.blocks.4.gamma2", "model.0.blocks.4.norm1.weight", "model.0.blocks.4.norm1.bias", "model.0.blocks.4.attn.temperature", "model.0.blocks.4.attn.qkv.weight", "model.0.blocks.4.attn.qkv.bias", "model.0.blocks.4.attn.proj.weight", "model.0.blocks.4.attn.proj.bias", "model.0.blocks.4.norm3.weight", "model.0.blocks.4.norm3.bias", "model.0.blocks.4.local_mp.conv1.weight", "model.0.blocks.4.local_mp.conv1.bias", "model.0.blocks.4.local_mp.bn.weight", "model.0.blocks.4.local_mp.bn.bias", "model.0.blocks.4.local_mp.bn.running_mean", "model.0.blocks.4.local_mp.bn.running_var", "model.0.blocks.4.local_mp.bn.num_batches_tracked", "model.0.blocks.4.local_mp.conv2.weight", "model.0.blocks.4.local_mp.conv2.bias", "model.0.blocks.4.norm2.weight", "model.0.blocks.4.norm2.bias", "model.0.blocks.4.mlp.fc1.weight", "model.0.blocks.4.mlp.fc1.bias", 
"model.0.blocks.4.mlp.fc2.weight", "model.0.blocks.4.mlp.fc2.bias", "model.0.blocks.5.gamma1", "model.0.blocks.5.gamma3", "model.0.blocks.5.gamma2", "model.0.blocks.5.norm1.weight", "model.0.blocks.5.norm1.bias", "model.0.blocks.5.attn.temperature", "model.0.blocks.5.attn.qkv.weight", "model.0.blocks.5.attn.qkv.bias", "model.0.blocks.5.attn.proj.weight", "model.0.blocks.5.attn.proj.bias", "model.0.blocks.5.norm3.weight", "model.0.blocks.5.norm3.bias", "model.0.blocks.5.local_mp.conv1.weight", "model.0.blocks.5.local_mp.conv1.bias", "model.0.blocks.5.local_mp.bn.weight", "model.0.blocks.5.local_mp.bn.bias", "model.0.blocks.5.local_mp.bn.running_mean", "model.0.blocks.5.local_mp.bn.running_var", "model.0.blocks.5.local_mp.bn.num_batches_tracked", "model.0.blocks.5.local_mp.conv2.weight", "model.0.blocks.5.local_mp.conv2.bias", "model.0.blocks.5.norm2.weight", "model.0.blocks.5.norm2.bias", "model.0.blocks.5.mlp.fc1.weight", "model.0.blocks.5.mlp.fc1.bias", "model.0.blocks.5.mlp.fc2.weight", "model.0.blocks.5.mlp.fc2.bias", "model.0.blocks.6.gamma1", "model.0.blocks.6.gamma3", "model.0.blocks.6.gamma2", "model.0.blocks.6.norm1.weight", "model.0.blocks.6.norm1.bias", "model.0.blocks.6.attn.temperature", "model.0.blocks.6.attn.qkv.weight", "model.0.blocks.6.attn.qkv.bias", "model.0.blocks.6.attn.proj.weight", "model.0.blocks.6.attn.proj.bias", "model.0.blocks.6.norm3.weight", "model.0.blocks.6.norm3.bias", "model.0.blocks.6.local_mp.conv1.weight", "model.0.blocks.6.local_mp.conv1.bias", "model.0.blocks.6.local_mp.bn.weight", "model.0.blocks.6.local_mp.bn.bias", "model.0.blocks.6.local_mp.bn.running_mean", "model.0.blocks.6.local_mp.bn.running_var", "model.0.blocks.6.local_mp.bn.num_batches_tracked", "model.0.blocks.6.local_mp.conv2.weight", "model.0.blocks.6.local_mp.conv2.bias", "model.0.blocks.6.norm2.weight", "model.0.blocks.6.norm2.bias", "model.0.blocks.6.mlp.fc1.weight", "model.0.blocks.6.mlp.fc1.bias", "model.0.blocks.6.mlp.fc2.weight", 
"model.0.blocks.6.mlp.fc2.bias", "model.0.blocks.7.gamma1", "model.0.blocks.7.gamma3", "model.0.blocks.7.gamma2", "model.0.blocks.7.norm1.weight", "model.0.blocks.7.norm1.bias", "model.0.blocks.7.attn.temperature", "model.0.blocks.7.attn.qkv.weight", "model.0.blocks.7.attn.qkv.bias", "model.0.blocks.7.attn.proj.weight", "model.0.blocks.7.attn.proj.bias", "model.0.blocks.7.norm3.weight", "model.0.blocks.7.norm3.bias", "model.0.blocks.7.local_mp.conv1.weight", "model.0.blocks.7.local_mp.conv1.bias", "model.0.blocks.7.local_mp.bn.weight", "model.0.blocks.7.local_mp.bn.bias", "model.0.blocks.7.local_mp.bn.running_mean", "model.0.blocks.7.local_mp.bn.running_var", "model.0.blocks.7.local_mp.bn.num_batches_tracked", "model.0.blocks.7.local_mp.conv2.weight", "model.0.blocks.7.local_mp.conv2.bias", "model.0.blocks.7.norm2.weight", "model.0.blocks.7.norm2.bias", "model.0.blocks.7.mlp.fc1.weight", "model.0.blocks.7.mlp.fc1.bias", "model.0.blocks.7.mlp.fc2.weight", "model.0.blocks.7.mlp.fc2.bias", "model.0.blocks.8.gamma1", "model.0.blocks.8.gamma3", "model.0.blocks.8.gamma2", "model.0.blocks.8.norm1.weight", "model.0.blocks.8.norm1.bias", "model.0.blocks.8.attn.temperature", "model.0.blocks.8.attn.qkv.weight", "model.0.blocks.8.attn.qkv.bias", "model.0.blocks.8.attn.proj.weight", "model.0.blocks.8.attn.proj.bias", "model.0.blocks.8.norm3.weight", "model.0.blocks.8.norm3.bias", "model.0.blocks.8.local_mp.conv1.weight", "model.0.blocks.8.local_mp.conv1.bias", "model.0.blocks.8.local_mp.bn.weight", "model.0.blocks.8.local_mp.bn.bias", "model.0.blocks.8.local_mp.bn.running_mean", "model.0.blocks.8.local_mp.bn.running_var", "model.0.blocks.8.local_mp.bn.num_batches_tracked", "model.0.blocks.8.local_mp.conv2.weight", "model.0.blocks.8.local_mp.conv2.bias", "model.0.blocks.8.norm2.weight", "model.0.blocks.8.norm2.bias", "model.0.blocks.8.mlp.fc1.weight", "model.0.blocks.8.mlp.fc1.bias", "model.0.blocks.8.mlp.fc2.weight", "model.0.blocks.8.mlp.fc2.bias", "model.0.blocks.9.gamma1", 
"model.0.blocks.9.gamma3", "model.0.blocks.9.gamma2", "model.0.blocks.9.norm1.weight", "model.0.blocks.9.norm1.bias", "model.0.blocks.9.attn.temperature", "model.0.blocks.9.attn.qkv.weight", "model.0.blocks.9.attn.qkv.bias", "model.0.blocks.9.attn.proj.weight", "model.0.blocks.9.attn.proj.bias", "model.0.blocks.9.norm3.weight", "model.0.blocks.9.norm3.bias", "model.0.blocks.9.local_mp.conv1.weight", "model.0.blocks.9.local_mp.conv1.bias", "model.0.blocks.9.local_mp.bn.weight", "model.0.blocks.9.local_mp.bn.bias", "model.0.blocks.9.local_mp.bn.running_mean", "model.0.blocks.9.local_mp.bn.running_var", "model.0.blocks.9.local_mp.bn.num_batches_tracked", "model.0.blocks.9.local_mp.conv2.weight", "model.0.blocks.9.local_mp.conv2.bias", "model.0.blocks.9.norm2.weight", "model.0.blocks.9.norm2.bias", "model.0.blocks.9.mlp.fc1.weight", "model.0.blocks.9.mlp.fc1.bias", "model.0.blocks.9.mlp.fc2.weight", "model.0.blocks.9.mlp.fc2.bias", "model.0.blocks.10.gamma1", "model.0.blocks.10.gamma3", "model.0.blocks.10.gamma2", "model.0.blocks.10.norm1.weight", "model.0.blocks.10.norm1.bias", "model.0.blocks.10.attn.temperature", "model.0.blocks.10.attn.qkv.weight", "model.0.blocks.10.attn.qkv.bias", "model.0.blocks.10.attn.proj.weight", "model.0.blocks.10.attn.proj.bias", "model.0.blocks.10.norm3.weight", "model.0.blocks.10.norm3.bias", "model.0.blocks.10.local_mp.conv1.weight", "model.0.blocks.10.local_mp.conv1.bias", "model.0.blocks.10.local_mp.bn.weight", "model.0.blocks.10.local_mp.bn.bias", "model.0.blocks.10.local_mp.bn.running_mean", "model.0.blocks.10.local_mp.bn.running_var", "model.0.blocks.10.local_mp.bn.num_batches_tracked", "model.0.blocks.10.local_mp.conv2.weight", "model.0.blocks.10.local_mp.conv2.bias", "model.0.blocks.10.norm2.weight", "model.0.blocks.10.norm2.bias", "model.0.blocks.10.mlp.fc1.weight", "model.0.blocks.10.mlp.fc1.bias", "model.0.blocks.10.mlp.fc2.weight", "model.0.blocks.10.mlp.fc2.bias", "model.0.blocks.11.gamma1", "model.0.blocks.11.gamma3", 
"model.0.blocks.11.gamma2", "model.0.blocks.11.norm1.weight", "model.0.blocks.11.norm1.bias", "model.0.blocks.11.attn.temperature", "model.0.blocks.11.attn.qkv.weight", "model.0.blocks.11.attn.qkv.bias", "model.0.blocks.11.attn.proj.weight", "model.0.blocks.11.attn.proj.bias", "model.0.blocks.11.norm3.weight", "model.0.blocks.11.norm3.bias", "model.0.blocks.11.local_mp.conv1.weight", "model.0.blocks.11.local_mp.conv1.bias", "model.0.blocks.11.local_mp.bn.weight", "model.0.blocks.11.local_mp.bn.bias", "model.0.blocks.11.local_mp.bn.running_mean", "model.0.blocks.11.local_mp.bn.running_var", "model.0.blocks.11.local_mp.bn.num_batches_tracked", "model.0.blocks.11.local_mp.conv2.weight", "model.0.blocks.11.local_mp.conv2.bias", "model.0.blocks.11.norm2.weight", "model.0.blocks.11.norm2.bias", "model.0.blocks.11.mlp.fc1.weight", "model.0.blocks.11.mlp.fc1.bias", "model.0.blocks.11.mlp.fc2.weight", "model.0.blocks.11.mlp.fc2.bias", "model.0.cls_attn_blocks.0.gamma1", "model.0.cls_attn_blocks.0.gamma2", "model.0.cls_attn_blocks.0.norm1.weight", "model.0.cls_attn_blocks.0.norm1.bias", "model.0.cls_attn_blocks.0.attn.q.weight", "model.0.cls_attn_blocks.0.attn.q.bias", "model.0.cls_attn_blocks.0.attn.k.weight", "model.0.cls_attn_blocks.0.attn.k.bias", "model.0.cls_attn_blocks.0.attn.v.weight", "model.0.cls_attn_blocks.0.attn.v.bias", "model.0.cls_attn_blocks.0.attn.proj.weight", "model.0.cls_attn_blocks.0.attn.proj.bias", "model.0.cls_attn_blocks.0.norm2.weight", "model.0.cls_attn_blocks.0.norm2.bias", "model.0.cls_attn_blocks.0.mlp.fc1.weight", "model.0.cls_attn_blocks.0.mlp.fc1.bias", "model.0.cls_attn_blocks.0.mlp.fc2.weight", "model.0.cls_attn_blocks.0.mlp.fc2.bias", "model.0.cls_attn_blocks.1.gamma1", "model.0.cls_attn_blocks.1.gamma2", "model.0.cls_attn_blocks.1.norm1.weight", "model.0.cls_attn_blocks.1.norm1.bias", "model.0.cls_attn_blocks.1.attn.q.weight", "model.0.cls_attn_blocks.1.attn.q.bias", "model.0.cls_attn_blocks.1.attn.k.weight", 
"model.0.cls_attn_blocks.1.attn.k.bias", "model.0.cls_attn_blocks.1.attn.v.weight", "model.0.cls_attn_blocks.1.attn.v.bias", "model.0.cls_attn_blocks.1.attn.proj.weight", "model.0.cls_attn_blocks.1.attn.proj.bias", "model.0.cls_attn_blocks.1.norm2.weight", "model.0.cls_attn_blocks.1.norm2.bias", "model.0.cls_attn_blocks.1.mlp.fc1.weight", "model.0.cls_attn_blocks.1.mlp.fc1.bias", "model.0.cls_attn_blocks.1.mlp.fc2.weight", "model.0.cls_attn_blocks.1.mlp.fc2.bias", "model.0.norm.weight", "model.0.norm.bias", "model.0.head.weight", "model.0.head.bias", "model.1.weight", "model.1.bias", "model.3.weight", "model.3.bias". 

In [None]:
# Test dataset: every sample presumably carries NUM_VIEWS independently augmented views,
# which the inference loop averages.
test_dataset = CGIARDataset_V4(
    features=X_test,
    images=test_images,
    transform=transform,
    num_views=NUM_VIEWS,
)

# Deterministic iteration order; the IDs travel with each batch anyway.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Accumulator for (ID, prediction) pairs filled by the inference loop.
predictions = list()

In [None]:
# Test-time-augmentation inference for a single fold: run the model on every augmented view
# of a batch and average the outputs across views.

# Inputs are moved to `device`, so make sure the model is there too and in inference mode.
model.to(device)
model.eval()

with torch.no_grad():
    for ids, images_list, growth_stage, season, _ in test_loader:
        # Tabular inputs are identical for every view, so move them once per batch.
        # NOTE(review): `.squeeze()` drops *all* size-1 dims, so a batch of size 1 collapses
        # to a 0-d tensor. Confirm the shapes XCITMultipleMLP expects.
        growth_stage = growth_stage.to(device).squeeze()
        season = season.to(device).squeeze()

        # average predictions from all the views
        outputs = torch.stack([
            model((growth_stage, season, images.to(device)))
            for images in images_list
        ]).mean(dim=0)

        # Record this batch's predictions. This must happen inside the loop; previously it
        # ran after the loop and kept only the final batch.
        # NOTE(review): if the model returns shape (B, 1), tolist() yields one-element lists.
        # Confirm the output shape before these values are written to the submission.
        predictions.extend(zip(ids, outputs.tolist()))

In [None]:
# Load the sample submission (from the same DATA_DIR as Test.csv) and fill its `extent`
# column with the predictions, matched by ID.
submission_df = pd.read_csv(DATA_DIR / 'SampleSubmission.csv')
submission_df['extent'] = submission_df['ID'].map(dict(predictions))

# Series.map leaves NaN for any ID without a prediction; fail loudly instead of writing an
# incomplete submission.
n_missing = int(submission_df['extent'].isna().sum())
assert n_missing == 0, f"{n_missing} submission IDs have no prediction"

# save the submission file
submission_df.to_csv('submission.csv', index=False)