In [1]:
import os
from tqdm.notebook import tqdm
import gc
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn as nn
import math
import timm
import pandas as pl
import torch
import numpy as np
from torch.amp import GradScaler
import cv2
import random
from tqdm.notebook import tqdm
from torch.autograd import Variable
from skimage.metrics import structural_similarity as ssim

from sklearn.model_selection import KFold

In [2]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(228)

In [3]:
pairs = pl.read_csv('/kaggle/input/ioai-contest-2/imgs/pairs_list.csv')
paths_embeds = pl.read_csv('/kaggle/input/ioai-contest-2/imgs/paths_embeds.csv')['image_path']
real_embeds = np.load('/kaggle/input/ioai-contest-2/imgs/real_embeds.npy')

In [4]:
class MCSDataset(torch.utils.data.Dataset):
    def __init__(self, image_path, target, imsize = 112):
        self.image_path = image_path
        self.target = target
        self.image_size = imsize

    def __len__(self):
        return len(self.target)

    def resize(self, img, interp):
        return  cv2.resize(
            img, (self.image_size, self.image_size), interpolation=interp)

    def __getitem__(self, idx):
        path = self.image_path[idx]
        target = self.target[idx]
        img = cv2.imread(f'/kaggle/input/ioai-contest-2/imgs/train/{path}')
        img = cv2.resize(
            img, (self.image_size, self.image_size), interpolation= cv2.INTER_LINEAR)

        img = (img / 255.) - 0.5
        img = np.transpose(img,(2,0,1)).astype(np.float32)
        img = torch.from_numpy(img)
        target = torch.from_numpy(target)

        return img, target

In [5]:
class Model(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.encoder = timm.create_model( model_name, global_pool='', num_classes=0, in_chans=3)
        x_ = torch.randn((1,3,112,112))
        shape_ = self.encoder(x_).shape[1]
        self.mapping = nn.Linear(shape_, 512)
        self.norm = nn.BatchNorm1d(512)#nn.BatchNorm1d(512)
    def forward(self, x):
        out = self.norm(self.mapping(self.encoder(x).mean(dim=(2, 3))))
        out = F.normalize(out)
        return out

In [6]:
__all__ = ['xception']

pretrained_settings = {
    'xception': {
        'imagenet': {
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
        }
    }
}


class SeparableConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super(SeparableConv2d,self).__init__()

        self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
        self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
    
    def forward(self,x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides!=1:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip=None
        
        self.relu = nn.ReLU(inplace=True)
        rep=[]

        filters=in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(filters))
        
        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*rep)

    def forward(self,inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x+=skip
        return x


class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf
    """
    def __init__(self, num_classes=1000):
        """ Constructor
        Args:
            num_classes: number of classes
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        #do relu here

        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)

        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)

        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
        self.bn3 = nn.BatchNorm2d(1536)

        #do relu here
        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

        # #------- init weights --------
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        #         m.weight.data.normal_(0, math.sqrt(2. / n))
        #     elif isinstance(m, nn.BatchNorm2d):
        #         m.weight.data.fill_(1)
        #         m.bias.data.zero_()
        # #-----------------------------

    def features(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        
        x = self.conv4(x)
        x = self.bn4(x)
        return x

    def logits(self, features):
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x


def xception(num_classes=1000, pretrained='imagenet', finetune=False):
    model = Xception(num_classes=num_classes)
    if pretrained:
        settings = pretrained_settings['xception'][pretrained]
        #assert num_classes == settings['num_classes'], \
        #    "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

        model = Xception(num_classes=1000)
        model.load_state_dict(torch.load(settings['url']))
        #model.load_state_dict(model_zoo.load_url(settings['url']))

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']
        
        if finetune:
            for param in model.parameters():
                param.requires_grad = False
                
        model.fc = nn.Linear(2048, num_classes)
        

    # TODO: ugly
    model.last_linear = model.fc
    del model.fc
    return model


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=512):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
        self.fc_bn = nn.BatchNorm1d(512)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        out = self.fc_bn(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])

In [7]:
def get_model(model_name, checkpoint_path):
    '''
    Model architecture choosing
    '''
    if model_name == 'ResNet50':
        net = ResNet50()
    elif model_name == 'Xception':
        net = xception(pretrained=False, num_classes=512)
    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint['net'])
    return net

In [8]:
def make_predict(model, val_loader, val_target, loss_func, DEVICE = 'cuda'):
    preds = []
    model.eval()
    average_loss = 0
    with torch.no_grad():
        for batch_number,  (img, target)  in enumerate(val_loader):
            img = img.to(DEVICE)
            target = target.to(DEVICE)

            with torch.amp.autocast('cuda'):
                outputs = model(img)
                loss = loss_func(outputs, target)

            average_loss += loss.cpu().detach().numpy()
            preds += [outputs.to('cpu').numpy()]
    preds = np.concatenate(preds)
    score = ((preds -  np.array(val_target)) ** 2).mean()
    print('MSE: ', ((preds -  np.array(val_target)) ** 2).mean())
    return score

In [9]:
train_paths = []
train_targets = []
val_paths = []
val_targets = []

for train_idxs, val_idxs in KFold(n_splits=10, random_state=42, shuffle=True).split(np.arange(paths_embeds.shape[0])):
    train_paths.append(paths_embeds.iloc[train_idxs])
    train_targets.append(real_embeds[train_idxs])

    val_paths.append(paths_embeds.iloc[val_idxs])
    val_targets.append(real_embeds[val_idxs])

In [10]:
def train_single_model(train_loader, val_loader, val_target, model, model_name, optimizer, loss_fn, scheduler, epochs=10, DEVICE='cuda'):
    
    scaler = GradScaler('cuda')

    best_model = -1
    best_mse = 10e9

    for epoch in range(epochs):
        model.train()
        average_loss = 0
        tk0 = tqdm(enumerate(train_loader), total = len(train_loader))
        for batch_number,  (img, target)  in tk0:
            optimizer.zero_grad()
            img = img.to(DEVICE)
            target = target.to(DEVICE)
            # continue
            with torch.amp.autocast('cuda'):
                outputs = model(img)
                loss = loss_func(outputs, target)
    
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
    
            average_loss += loss.cpu().detach().numpy()
            tk0.set_postfix(loss=average_loss / (batch_number + 1),lr = scheduler.get_last_lr()[0], stage="train", epoch = epoch)
        val_score = make_predict(model, val_loader, val_target, loss_func)
        if val_score < best_mse:
            best_mse = val_score
            best_model = model
        
    std_m = best_model.state_dict()
    torch.save(std_m, f'hmm_model_student_{model_name}_mse={best_mse}.pt')
    return best_model,best_mse

In [11]:
#models_list = ['resnet18', 'resnet18', 'resnet18', 'resnet18', 'resnet34', 'resnet34', 'resnet34', 'resnet34', 'ResNet50', 'ResNet50']
models_list = [['resnet34'] * 10, ['ResNet50']*10]
model2link = {'ResNet50': '/kaggle/input/tanya_resnet50/pytorch/default/1/best_model_chkpt-resnet50.t7',
             'Xception': '/kaggle/input/xxx/pytorch/default/1/best_model_chkpt-xception.t7'}
models_chkps = []
model_scores = []

for j in range(len(models_list)):
    for i in range(len(models_list[0])):

        print(f'FOLD {i}, model: {models_list[j][i]}')
    
        batch_size = 64
        valid_batch_size = 64
        epochs = 20
        lr = 1e-2
        clip_grad_norm = 15.28
        DEVICE = 'cuda'
        params_train = {'batch_size': batch_size, 'shuffle': True, 'drop_last': True, 'num_workers': 2}
        params_val = {'batch_size': batch_size, 'shuffle': False, 'drop_last': False, 'num_workers': 2}
    
    
        train_loader = torch.utils.data.DataLoader(MCSDataset(train_paths[i].reset_index(drop=True), train_targets[i]), **params_train)
        val_loader = torch.utils.data.DataLoader(MCSDataset(val_paths[i].reset_index(drop=True), val_targets[i]), **params_val)
    
        if models_list[j][i] in model2link:
            model = get_model(models_list[j][i], model2link[models_list[j][i]]).cuda()
        else:
            model = Model(models_list[j][i]).cuda()
        num_train_steps = int(len(train_loader) / batch_size  * epochs)
        loss_func = torch.nn.MSELoss()
        
        optimizer = torch.optim.AdamW(model.parameters(), lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * epochs, 1e-6)
    
        best_model, best_score = train_single_model(
            train_loader=train_loader, 
            val_loader=val_loader,
            val_target=val_targets[i],
            model=model,
            model_name=models_list[j][i],
            optimizer=optimizer,
            loss_fn=loss_func,
            scheduler=scheduler,
            epochs=epochs,
            DEVICE=DEVICE
        )
    
        models_chkps.append(best_model)
        model_scores.append(best_score)

FOLD 0, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0025800273


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022789598


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002031942


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018490874


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017616935


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017191465


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016745032


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016404435


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016425938


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016284137


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016073278


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016016006


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001598506


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016009364


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016026885


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016055455


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016100081


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016128536


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016130995


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016136954
FOLD 1, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0026655374


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022227983


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019678203


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018586962


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017453356


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017020128


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016646713


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016405821


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016067942


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015939495


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015920794


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015892664


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015824633


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015844444


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015891981


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.00159143


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015936961


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001596894


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015991562


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015981419
FOLD 2, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0026531518


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002247239


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.00199398


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018631003


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017842314


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017126513


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016497566


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016466024


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016271534


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016043811


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015842037


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015870648


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015807857


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015816765


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001584964


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015883241


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015932146


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015942963


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015954535


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015963636
FOLD 3, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0026009195


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002229469


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019878328


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018802076


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017561178


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016912532


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016623966


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016249747


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001596597


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015885937


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015725379


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015651557


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015672236


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001560071


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015651226


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015700336


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015728152


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015753611


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015756133


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001577307
FOLD 4, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0026022692


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022448502


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0020525598


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018730694


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017783489


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016830161


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016548936


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001624633


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001613702


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015898311


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015852768


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015827424


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015786312


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015815194


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015835873


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001588899


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015922044


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015951912


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015955068


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001598375
FOLD 5, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0025795028


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022420944


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002015139


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018917857


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017619237


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016969686


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017026175


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016472467


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016116889


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.00159522


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015811224


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015731127


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015706591


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015694908


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015737832


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015790946


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001582208


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015849693


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015849256


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015860228
FOLD 6, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002610246


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022239177


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0020243968


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019019537


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017921514


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017218438


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016782825


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016502056


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016265635


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016268744


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016073401


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001601772


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015964918


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015982578


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016023521


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016039206


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016100919


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016136117


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016145882


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016154072
FOLD 7, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0025899406


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022213017


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0020281037


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018619214


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017655573


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016969519


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016643304


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016318369


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016134245


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016034252


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015930622


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015942592


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001585917


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015848791


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015922156


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015941345


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015977138


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015996528


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016034787


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001602664
FOLD 8, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002576225


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022103258


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0020274709


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018743408


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017669485


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016991318


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016853876


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016391431


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016338339


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015947428


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015843177


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015744109


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015783856


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015828911


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015811478


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015858441


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015880837


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015926084


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015933461


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015937336
FOLD 9, model: resnet34


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002608087


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0022698531


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002032988


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018826632


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018044312


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016857102


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016675603


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016577082


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016188965


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015961868


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015941426


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015865972


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015810857


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015834077


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015838275


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015873308


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015924332


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015957656


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015975201


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015982811
FOLD 0, model: ResNet50


  checkpoint = torch.load(checkpoint_path)


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001952109


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0020061408


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001859752


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017849421


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018119867


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001748503


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001680846


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015681004


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016146264


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015165296


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001475789


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013930702


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013691273


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013267181


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013157815


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012971626


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012930969


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012978441


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013013235


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013026777
FOLD 1, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019822198


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018976534


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001834886


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018496466


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001796164


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017851403


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016889779


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015681745


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015509517


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015089933


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014483599


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014238105


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013564928


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001303231


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001278751


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012650072


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012620558


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012698384


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012763904


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012743884
FOLD 2, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001942759


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019410624


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018278123


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019160295


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019159749


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018718932


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017335182


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017039148


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016262918


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015368016


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015198003


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014285998


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013784104


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013335258


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013019388


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012800455


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012754791


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012796888


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012862474


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012860787
FOLD 3, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0020158081


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018709941


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018269997


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018888912


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001859161


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017650089


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017183626


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015776455


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015416906


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014615005


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014145338


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014045801


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013569233


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013280618


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012959767


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012777114


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012759195


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001285444


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012921739


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012922192
FOLD 4, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001967656


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001853792


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.002021288


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019253498


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019930976


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018626286


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016411185


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016700003


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015271936


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014735785


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014098524


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013770064


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013478047


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013178171


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001283509


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012680852


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012700325


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012721051


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012768547


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012801406
FOLD 5, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019537588


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019100289


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001835206


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017719087


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017581844


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017233994


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016470657


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016876907


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014574472


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014493734


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014315926


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013707705


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001323862


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012914021


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012711985


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012702033


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012689017


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012744647


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012788136


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012809066
FOLD 6, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019570664


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019152879


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018862508


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017917166


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018315329


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018215253


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016937283


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001622851


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015414696


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015393622


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014290197


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013945692


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013460774


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013084888


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013019515


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012754187


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012759959


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012767359


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012821065


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012843239
FOLD 7, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019902813


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019593656


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001913995


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017718017


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018231892


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017410927


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016421712


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016437988


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015087482


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014608592


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014207107


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013520702


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001330263


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012859508


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012636081


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012566049


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012596018


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001260911


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012667052


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012689219
FOLD 8, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019495492


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001991658


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018938275


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018765592


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018834153


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018391503


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017136044


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016929962


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016766678


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015694655


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014924145


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014195145


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013534682


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001308766


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012919698


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012781557


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012740224


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012731206


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012793101


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012810763
FOLD 9, model: ResNet50


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019814717


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019546524


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0018186248


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017465628


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0019107192


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017487409


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0017651153


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015485819


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0016157049


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0015023462


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0014146637


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013993831


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001361229


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0013132963


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.001288199


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012763474


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012779991


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012785539


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012825325


  0%|          | 0/140 [00:00<?, ?it/s]

MSE:  0.0012841644


In [12]:
def read_img(path, image_size = 112):
    img = cv2.imread(f'/kaggle/input/ioai-contest-2/imgs/train/{path}')
    img_ = cv2.resize(
        img, (image_size, image_size), interpolation= cv2.INTER_LINEAR)
    img = (img_ / 255.) - 0.5
    img = np.transpose(img,(2,0,1)).astype(np.float32)
    img = torch.from_numpy(img)
    return img, img_

In [13]:
weights = [1/x for x in model_scores]
models_for_infer = models_chkps

max_iter = 50
loss = nn.MSELoss()
eps = 0.005
attacked_img_dict = {}


for sour, targ in tqdm(zip(pairs['source_imgs'], pairs['target_imgs']),total=len(pairs)):

    target_descriptors = torch.ones((5, 512), dtype=torch.float32)
    targ = targ.split('|')
    sour = sour.split('|')

    list_tagt_img = []
    for i, t in enumerate(targ):
        img, orig_tgt = read_img(t)
        list_tagt_img += [orig_tgt]
        img = img.unsqueeze(0).cuda(non_blocking = True)
        res = model(Variable(img, requires_grad=False)).data.cpu().squeeze()
        target_descriptors[i] = res

    for ii, s in enumerate(sour): 
        img, orig_img = read_img(s)
        img = img.unsqueeze(0).cuda(non_blocking = True)
        input_var = Variable(img, requires_grad=True)
        attacked_img = orig_img
        last_iter = -1
        for iter_number in (range(max_iter)):
            adv_noise = torch.zeros((3,112,112)).cuda(non_blocking = True)
            target_out = Variable(target_descriptors.cuda(non_blocking=True), requires_grad=False)
            input_var.grad = None
            models_for_infer[0].eval()
            out = models_for_infer[0](input_var) * weights[0]
            for i, m in enumerate(models_for_infer[1:]):
                m.eval()
                out += m(input_var) * weights[i+1]
            out /= np.sum(weights)
            calc_loss = loss(out, target_out)
            calc_loss.backward()
            adv_noise = input_var.grad.data.squeeze()
            adv_noise.div_(adv_noise.std())
            adv_noise = eps * torch.clamp(adv_noise, min=-2., max=2.)

            input_var.data = input_var.data - adv_noise
            out_l = model(input_var)
            changed_img = input_var.data.cpu().squeeze()
            changed_img = ((changed_img + 0.5) * 255)
            changed_img[changed_img < 0] = 0
            changed_img[changed_img > 255] = 255
            changed_img = np.transpose(changed_img.numpy(), (1, 2, 0)).astype(np.int16)
            ssim_score = ssim(orig_img, changed_img, channel_axis=2, data_range = 256)
            last_iter = iter_number
            if ssim_score < 0.95:
                break
            else:
                attacked_img = changed_img
        #print(last_iter)
        attacked_img_dict[s] = attacked_img

  0%|          | 0/1000 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


In [14]:
sample_submission = pl.read_csv('/kaggle/input/ioai-contest-2/imgs/sample_submission.csv')
sample_submission_df = pl.DataFrame()
sample_submission_df['Id'] = sample_submission['Id']

result = []
for id_ in tqdm(sample_submission_df['Id']):
    result += [ '|'.join([str(i) for i in attacked_img_dict[id_].flatten().tolist()])  ]
sample_submission_df['Target'] = result
sample_submission_df.to_csv('feel_the_agi.csv', index = None)

  0%|          | 0/5000 [00:00<?, ?it/s]