In [1]:
from torch.utils.data import Dataset
import torch
from torchvision import transforms
from PIL import Image
import skimage.io
import numpy as np
import pandas as pd
import pretrainedmodels
from torch.utils.data import DataLoader
import time
import glob
import os
from multiprocessing import Pool
from transforms import *
from torch import nn

# 4 Channel Dataset Object

In [2]:
model_configs = {'polynet':{
                   'input_size': 331,
                   'input_mean': [0.485, 0.456, 0.406, 0.406],
                   'input_std' : [0.229, 0.224, 0.225, 0.225]
                    },
                 'resnet34':{
                   'input_size': 224,
                   'input_mean': [0.485, 0.456, 0.406, 0.406],
                   'input_std' : [0.229, 0.224, 0.225, 0.225]
                    },
                 'senet154':{
                   'input_size': 224,
                   'input_mean': [0.485, 0.456, 0.406, 0.406],
                   'input_std' : [0.229, 0.224, 0.225, 0.225]
                    },
                 'nasnetamobile':{
                   'input_size': 224,
                   'input_mean': [0.5],
                   'input_std' : [0.5]
                    },
                 'bninception':{
                   'input_size': 299,
                   'input_mean': [0.5],
                   'input_std' : [0.5]
                    },
                 'xception':{
                   'input_size': 299,
                   'input_mean': [0.5],
                   'input_std' : [0.5]
                    }
                }

class AtlasData(Dataset):
    def __init__(self, split, train = True, model = 'bninception'):
        self.split = split
        self.train = train
        self.train_str = 'train' if self.train else 'test'
        self.text_file = 'data/atlas_{}_split_{}.txt'.format(self.train_str, self.split)
        
        self.data = [[y for y in x.strip().split(' ')] for x in open(self.text_file, 'r').readlines()]
        self.imgs = [x[0] for x in self.data]
        self.labels = [[int(p) for p in x[1:]] for x in self.data]
        
        self.input_size = model_configs[model]['input_size']
        self.input_mean = model_configs[model]['input_mean']
        self.input_std = model_configs[model]['input_std']
        
        self.transforms = transforms.Compose([GroupRandomRotate(360),
                                              GroupScale(self.input_size),
                                              Stack(roll=False),
                                              ToTorchFormatTensor(div=True),
                                              transforms.Normalize(self.input_mean, self.input_std),
                                            ])
        
    
    
    
    def load_image_stack(self, image_id):
        colors = ['red', 'green', 'blue', 'yellow']
        absolute_paths = ["data/train/{}_{}.png".format(image_id, color) for color in colors]
        images = [Image.open(path).convert('L') for path in absolute_paths]
        
        return images
    
    def __getitem__(self, i):
        image_id = self.imgs[i]
        image = self.load_image_stack(image_id)
        image = self.transforms(image)
        
        label = self.labels[i]
        label_arr = np.zeros(28, dtype = np.float32)
        [np.put(label_arr, x, 1) for x in label]
        
        label_arr = torch.from_numpy(label_arr)
        
        return image, label_arr, label
        
    def __len__(self):
        return len(self.imgs)

# Construct RGBY Model

In [3]:
def construct_rgby_model(model):
    modules = list(model.modules())
    first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
    conv_layer = modules[first_conv_idx]
    container = modules[first_conv_idx - 1]

    params = [x.clone() for x in conv_layer.parameters()]
    kernel_size = params[0].size()
    new_kernel_size = kernel_size[:1] + (4, ) + kernel_size[2:]
    new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()

    new_conv = nn.Conv2d(4, conv_layer.out_channels,
                         conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                         bias=True if len(params) == 2 else False)
    
    new_conv.weight.data = new_kernels
    if len(params) == 2:
        new_conv.bias.data = params[1].data 
    layer_name = list(container.state_dict().keys())[0][:-7] 

    setattr(container, layer_name, new_conv)
    return model

# Training Script

In [None]:
import pretrainedmodels
from torch.utils.data import DataLoader
import time
from sklearn.metrics import f1_score


def f1_micro(y_true, y_preds, thresh=0.5, eps=1e-20):
    preds_bin = y_preds > thresh # binary representation from probabilities (not relevant)
    truepos = preds_bin * y_true
    
    p = truepos.sum() / (preds_bin.sum() + eps) # take sums and calculate precision on scalars
    r = truepos.sum() / (y_true.sum() + eps) # take sums and calculate recall on scalars
    
    f1 = 2*p*r / (p+r+eps) # we calculate f1 on scalars
    return f1

def f1_macro(y_true, y_preds, thresh=0.5, eps=1e-20):
    preds_bin = y_preds > thresh # binary representation from probabilities (not relevant)
    truepos = preds_bin * y_true

    p = truepos.sum(axis=0) / (preds_bin.sum(axis=0) + eps) # sum along axis=0 (classes)
                                                            # and calculate precision array
    r = truepos.sum(axis=0) / (y_true.sum(axis=0) + eps)    # sum along axis=0 (classes) 
                                                            #  and calculate recall array

    f1 = 2*p*r / (p+r+eps) # we calculate f1 on arrays
    return np.mean(f1)


class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        
    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
            ((-max_val).exp() + (-input - max_val).exp()).log()

        invprobs = torch.nn.functional.logsigmoid(-input * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        
        return loss.sum(dim=1).mean()

def train_and_val(model_name, split, batch_size, epochs, lr, start_epoch):
    
    model = pretrainedmodels.__dict__[model_name](num_classes = 1000)
    model = construct_rgby_model(model)
            
    num_features = model.last_linear.in_features 
    model.last_linear = torch.nn.Linear(num_features, 28)
    
    if glob.glob('{}_rgby_0*'.format(model_name)):
        pth_file = torch.load('{}_rgby_0.pth.tar'.format(model_name))
        state_dict = pth_file['state_dict']
        model.load_state_dict(state_dict)
        start_epoch = pth_file['epoch']
        
    model.cuda()

    train_dataset = AtlasData(split = split, train = True, model = model_name)
    val_dataset = AtlasData(split = split, train = False, model = model_name)

    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
    val_loader = DataLoader(val_dataset, batch_size = 1, shuffle = False)
    
    log_loss = torch.nn.BCEWithLogitsLoss()
    focal_loss = FocalLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)   
    
    
    for epoch in range(start_epoch,epochs+1):
        
        train(model, train_loader, optimizer, log_loss, focal_loss, epoch)
        avg_loss, f1_score = validate(model, val_loader, log_loss, focal_loss, epoch)
                    
        if epoch % 10 == 0:
            filename = '{}_rgby_{}_{}.pth.tar'.format(model_name, split, epoch)
        else:
            filename = '{}_rgby_{}.pth.tar'.format(model_name, split)
            
        state = {'loss': avg_loss, 'f1_score': f1_score, 'epoch': epoch+1, 'state_dict': model.state_dict()}           
        torch.save(state, filename)

    
def train(model, train_loader, optimizer, log_loss, focal_loss, epoch):
    
    model.train()
    start = time.time()
    losses = []
    for i, (images, label_arrs, labels) in enumerate(train_loader):
        images = images.cuda()
        label_arrs = label_arrs.cuda()

        outputs = model(images)

        loss = log_loss(outputs, label_arrs)
        losses.append(loss.data[0])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        end = time.time()
        elapsed = end-start
        
        if i%100==0:
            print("Epoch [{}], Iteration [{}/{}], Loss: {:.4f} ({:.4f}), Elapsed Time {:.4f}"
                .format(epoch, i+1, len(train_loader), loss.data[0], sum(losses)/len(losses), elapsed))
            
    print("Average Loss: {}".format(sum(losses)/len(losses)))
            

def validate(model, val_loader, log_loss, focal_loss, epoch):
    
    model.eval()
    
    losses = []
    y_pred = np.zeros(len(val_loader) * 28).reshape(len(val_loader), 28)
    y_true = np.zeros(len(val_loader) * 28).reshape(len(val_loader), 28)

    for i, (images, label_arrs, labels) in enumerate(val_loader):
        images = images.cuda()
        label_arrs_cuda = label_arrs.cuda()

        raw_predictions = model(images)
        outputs = raw_predictions.data
        
        loss = log_loss(outputs, label_arrs_cuda) 
        losses.append(loss.data)
        
        predictions = np.arange(28)[raw_predictions.data[0] > 0.15]
        
        y_pred[i,:] = predictions
        y_true[i,:] = label_arrs
        
        if sum(predictions) == 0:
            prediction = np.argmax(raw_predictions.detach().cpu().numpy())
            predictions = np.zeros(28)
            np.put(predictions, prediction, 1)
        
        
        if i%1000==0:
            print('Testing {}/{}: Loss {}'.format(i, 
                                                 len(val_loader), 
                                                 sum(losses)/len(losses)))
                                                                         
    score = f1_macro(y_true, y_pred)
    avg_loss = sum(losses)/len(losses)
    print("Avg Loss {}".format(avg_loss))
    print("Score {}".format(score))

    return avg_loss, score

split = 0
batch_size = 16
epochs = 100
lr = 0.0001

model_list = ['xception',
              'nasnetamobile',
              'resnet34', 
              'senet154',
              'polynet',]

start_epoch = 1

for model_name in model_list:
    print("Training {}".format(model_name))
    train_and_val(model_name, split, batch_size, epochs, lr, start_epoch)

Training xception


  "please use transforms.Resize instead.")


Epoch [5], Iteration [1/1289], Loss: 0.1055 (0.1055), Elapsed Time 0.6941
Epoch [5], Iteration [101/1289], Loss: 0.0924 (0.0880), Elapsed Time 72.5745
Epoch [5], Iteration [201/1289], Loss: 0.0971 (0.0867), Elapsed Time 143.4503
Epoch [5], Iteration [301/1289], Loss: 0.0730 (0.0862), Elapsed Time 212.0373
Epoch [5], Iteration [401/1289], Loss: 0.1198 (0.0863), Elapsed Time 278.6110
Epoch [5], Iteration [501/1289], Loss: 0.0817 (0.0870), Elapsed Time 346.4746
Epoch [5], Iteration [601/1289], Loss: 0.1187 (0.0872), Elapsed Time 414.5736
Epoch [5], Iteration [701/1289], Loss: 0.0860 (0.0874), Elapsed Time 481.5677
Epoch [5], Iteration [801/1289], Loss: 0.1073 (0.0873), Elapsed Time 550.4969
Epoch [5], Iteration [901/1289], Loss: 0.0859 (0.0870), Elapsed Time 617.7292
Epoch [5], Iteration [1001/1289], Loss: 0.1071 (0.0873), Elapsed Time 685.4109
Epoch [5], Iteration [1101/1289], Loss: 0.1035 (0.0872), Elapsed Time 752.6763
Epoch [5], Iteration [1201/1289], Loss: 0.0736 (0.0870), Elapsed Ti

Epoch [10], Iteration [401/1289], Loss: 0.0299 (0.0629), Elapsed Time 240.5912
Epoch [10], Iteration [501/1289], Loss: 0.0621 (0.0636), Elapsed Time 300.7936
Epoch [10], Iteration [601/1289], Loss: 0.0801 (0.0639), Elapsed Time 360.1015
Epoch [10], Iteration [701/1289], Loss: 0.0920 (0.0640), Elapsed Time 419.5753
Epoch [10], Iteration [801/1289], Loss: 0.0545 (0.0641), Elapsed Time 480.4198
Epoch [10], Iteration [901/1289], Loss: 0.0655 (0.0646), Elapsed Time 542.0334
Epoch [10], Iteration [1001/1289], Loss: 0.0695 (0.0647), Elapsed Time 600.6816
Epoch [10], Iteration [1101/1289], Loss: 0.0360 (0.0648), Elapsed Time 659.7306
Epoch [10], Iteration [1201/1289], Loss: 0.0707 (0.0649), Elapsed Time 718.5028
Average Loss: 0.06500355899333954
Testing 0/10449: Loss 0.050384100526571274
Testing 1000/10449: Loss 0.0825512483716011
Testing 2000/10449: Loss 0.08099889755249023
Testing 3000/10449: Loss 0.0825745165348053
Testing 4000/10449: Loss 0.08266692608594894
Testing 5000/10449: Loss 0.0831

Epoch [15], Iteration [701/1289], Loss: 0.0385 (0.0448), Elapsed Time 408.5679
Epoch [15], Iteration [801/1289], Loss: 0.0388 (0.0449), Elapsed Time 466.9868
Epoch [15], Iteration [901/1289], Loss: 0.0401 (0.0450), Elapsed Time 524.2537
Epoch [15], Iteration [1001/1289], Loss: 0.0275 (0.0453), Elapsed Time 581.3808
Epoch [15], Iteration [1101/1289], Loss: 0.0478 (0.0454), Elapsed Time 638.6042
Epoch [15], Iteration [1201/1289], Loss: 0.0374 (0.0457), Elapsed Time 695.7257
Average Loss: 0.045806944370269775
Testing 0/10449: Loss 0.073348268866539
Testing 1000/10449: Loss 0.08964750170707703
Testing 2000/10449: Loss 0.08834327012300491
Testing 3000/10449: Loss 0.08951395750045776
Testing 4000/10449: Loss 0.08932550251483917
Testing 5000/10449: Loss 0.08918073028326035
Testing 6000/10449: Loss 0.08889801800251007
Testing 7000/10449: Loss 0.08915096521377563
Testing 8000/10449: Loss 0.08950851112604141
Testing 9000/10449: Loss 0.08950302004814148
Testing 10000/10449: Loss 0.090001285076141

Epoch [20], Iteration [1001/1289], Loss: 0.0398 (0.0304), Elapsed Time 566.4373
Epoch [20], Iteration [1101/1289], Loss: 0.0251 (0.0305), Elapsed Time 623.5984
Epoch [20], Iteration [1201/1289], Loss: 0.0202 (0.0308), Elapsed Time 680.5304
Average Loss: 0.03089461475610733
Testing 0/10449: Loss 0.0531555600464344
Testing 1000/10449: Loss 0.09645521640777588
Testing 2000/10449: Loss 0.09774275869131088
Testing 3000/10449: Loss 0.09842666983604431
Testing 4000/10449: Loss 0.09866738319396973
Testing 5000/10449: Loss 0.09805353730916977
Testing 6000/10449: Loss 0.09841042757034302
Testing 7000/10449: Loss 0.09825305640697479
Testing 8000/10449: Loss 0.09856606274843216
Testing 9000/10449: Loss 0.09851410984992981
Testing 10000/10449: Loss 0.09914117306470871
Avg Loss 0.0992145761847496
Score 0.6434988910021302
Epoch [21], Iteration [1/1289], Loss: 0.0350 (0.0350), Elapsed Time 0.5893
Epoch [21], Iteration [101/1289], Loss: 0.0212 (0.0261), Elapsed Time 57.1232
Epoch [21], Iteration [201/1

Average Loss: 0.020707665011286736
Testing 0/10449: Loss 0.04068291187286377
Testing 1000/10449: Loss 0.11050763726234436
Testing 2000/10449: Loss 0.10879965871572495
Testing 3000/10449: Loss 0.10889080166816711
Testing 4000/10449: Loss 0.10852678120136261
Testing 5000/10449: Loss 0.10794684290885925
Testing 6000/10449: Loss 0.10780800133943558
Testing 7000/10449: Loss 0.10836219042539597
Testing 8000/10449: Loss 0.10963139683008194
Testing 9000/10449: Loss 0.10951690375804901
Testing 10000/10449: Loss 0.10946264863014221
Avg Loss 0.10939693450927734
Score 0.6353056753666605
Epoch [26], Iteration [1/1289], Loss: 0.0111 (0.0111), Elapsed Time 0.6014
Epoch [26], Iteration [101/1289], Loss: 0.0214 (0.0170), Elapsed Time 59.3872
Epoch [26], Iteration [201/1289], Loss: 0.0118 (0.0181), Elapsed Time 116.0597
Epoch [26], Iteration [301/1289], Loss: 0.0351 (0.0186), Elapsed Time 174.3398
Epoch [26], Iteration [401/1289], Loss: 0.0122 (0.0185), Elapsed Time 231.3124
Epoch [26], Iteration [501/1

Testing 3000/10449: Loss 0.1204676479101181
Testing 4000/10449: Loss 0.12003310769796371
Testing 5000/10449: Loss 0.11931385844945908
Testing 6000/10449: Loss 0.11972671002149582
Testing 7000/10449: Loss 0.11975670605897903
Testing 8000/10449: Loss 0.12057864665985107
Testing 9000/10449: Loss 0.12027401477098465
Testing 10000/10449: Loss 0.1203642189502716
Avg Loss 0.12031877040863037
Score 0.6459856467201924
Epoch [31], Iteration [1/1289], Loss: 0.0029 (0.0029), Elapsed Time 0.5836
Epoch [31], Iteration [101/1289], Loss: 0.0127 (0.0134), Elapsed Time 59.0557
Epoch [31], Iteration [201/1289], Loss: 0.0365 (0.0126), Elapsed Time 116.0851
Epoch [31], Iteration [301/1289], Loss: 0.0165 (0.0130), Elapsed Time 173.3989
Epoch [31], Iteration [401/1289], Loss: 0.0179 (0.0129), Elapsed Time 230.1573
Epoch [31], Iteration [501/1289], Loss: 0.0128 (0.0129), Elapsed Time 287.3784
Epoch [31], Iteration [601/1289], Loss: 0.0035 (0.0127), Elapsed Time 344.1567
Epoch [31], Iteration [701/1289], Loss:

Testing 7000/10449: Loss 0.12474077194929123
Testing 8000/10449: Loss 0.1260243058204651
Testing 9000/10449: Loss 0.12636400759220123
Testing 10000/10449: Loss 0.12693214416503906
Avg Loss 0.12701071798801422
Score 0.6385648444723705
Epoch [36], Iteration [1/1289], Loss: 0.0060 (0.0060), Elapsed Time 0.5871
Epoch [36], Iteration [101/1289], Loss: 0.0024 (0.0090), Elapsed Time 59.1563
Epoch [36], Iteration [201/1289], Loss: 0.0101 (0.0089), Elapsed Time 117.5944
Epoch [36], Iteration [301/1289], Loss: 0.0030 (0.0092), Elapsed Time 175.1359
Epoch [36], Iteration [401/1289], Loss: 0.0038 (0.0098), Elapsed Time 233.3639
Epoch [36], Iteration [501/1289], Loss: 0.0056 (0.0098), Elapsed Time 291.0013
Epoch [36], Iteration [601/1289], Loss: 0.0143 (0.0099), Elapsed Time 348.8324
Epoch [36], Iteration [701/1289], Loss: 0.0041 (0.0101), Elapsed Time 407.5749
Epoch [36], Iteration [801/1289], Loss: 0.0051 (0.0100), Elapsed Time 466.1925
Epoch [36], Iteration [901/1289], Loss: 0.0045 (0.0103), Ela

Avg Loss 0.13534699380397797
Score 0.6439986099666418
Epoch [41], Iteration [1/1289], Loss: 0.0027 (0.0027), Elapsed Time 0.5744
Epoch [41], Iteration [101/1289], Loss: 0.0071 (0.0079), Elapsed Time 57.5706
Epoch [41], Iteration [201/1289], Loss: 0.0035 (0.0084), Elapsed Time 114.3131
Epoch [41], Iteration [301/1289], Loss: 0.0051 (0.0081), Elapsed Time 171.0314
Epoch [41], Iteration [401/1289], Loss: 0.0048 (0.0080), Elapsed Time 227.8947
Epoch [41], Iteration [501/1289], Loss: 0.0095 (0.0083), Elapsed Time 284.8739
Epoch [41], Iteration [601/1289], Loss: 0.0220 (0.0085), Elapsed Time 341.6446
Epoch [41], Iteration [701/1289], Loss: 0.0081 (0.0085), Elapsed Time 398.4111
Epoch [41], Iteration [801/1289], Loss: 0.0024 (0.0086), Elapsed Time 454.6770
Epoch [41], Iteration [901/1289], Loss: 0.0049 (0.0088), Elapsed Time 511.0400
Epoch [41], Iteration [1001/1289], Loss: 0.0056 (0.0089), Elapsed Time 567.3522
Epoch [41], Iteration [1101/1289], Loss: 0.0147 (0.0088), Elapsed Time 624.9997
E

In [None]:
model = pretrainedmodels.__dict__['polynet']()

print(model.mean)
print(model.std)
print(model.input_size)

# Generate Predictions

In [8]:
class EvalAtlasData(Dataset):
    def __init__(self, model = 'bninception'):
        self.image_ids = sorted(set([x.split('_')[0] for x in os.listdir('data/test')]))
        
        self.input_size = model_configs[model]['input_size']
        self.input_mean = model_configs[model]['input_mean']
        self.input_std = model_configs[model]['input_std']
        
        self.transforms = transforms.Compose([transforms.Resize(self.input_size),
                                              transforms.ToTensor(),
                                              transforms.Normalize(self.input_mean, self.input_std),
                                            ])
        
        
    def load_image_stack(self, image_id):
        colors = ['red', 'green', 'blue', 'yellow']
        absolute_paths = ["data/test/{}_{}.png".format(image_id, color) for color in colors]
        
        images = [skimage.io.imread(path) for path in absolute_paths]
        
        image_red = images[0]
        image_green = images[1] + (images[3]/2).astype(np.uint8)
        image_blue = images[2] + (images[3]/2).astype(np.uint8)
        
        final_image = np.stack((image_red, image_green, image_blue), -1)
        to_display = Image.fromarray(final_image)
        return to_display
    
    def dump_image(self, i):
        image = self.load_image_stack(self.image_ids[i])
        image_name = "data/test/{}_{}.png".format(self.image_ids[i], 'stacked')
        image.save(image_name)
        print("Saved", image_name)
        
    def load_image(self, i):
        image_id = self.image_ids[i]
        image_path = "data/test/{}_{}.png".format(image_id, 'stacked')
        image = Image.open(image_path)
        
        return image, image_id
    
    def __len__(self):
        return len(self.image_ids)
        
    def __getitem__(self, i):
        image, image_id = self.load_image(i)
        image = self.transforms(image)
        
        return image_id, image

In [9]:
import torch
import pretrainedmodels
import pandas as pd
from torch.nn import Softmax

def generate_preds(model_name):
    model = pretrainedmodels.__dict__[model_name](num_classes = 1000, pretrained = 'imagenet')
    in_features = model.last_linear.in_features
    model.last_linear = torch.nn.Linear(in_features, 28)
    
    if model_name == 'polynet':
        model = torch.nn.DataParallel(model, device_ids = [0,1,2,3]).cuda()
        model.load_state_dict(torch.load('{}_0.pth.tar'.format(model_name)))
        model = model.eval()
    else:
        model.load_state_dict(torch.load('{}_0.pth.tar'.format(model_name)))
        model = model.eval()
        model.cuda()

    eval_data = EvalAtlasData(model = model_name)
    dataloader = DataLoader(eval_data, 1, False)
    
    preds = []
    for i, (image_id, images) in enumerate(dataloader):
        images = images.cuda()

        raw_predictions = (model(images))
        predictions = np.argwhere(raw_predictions.data[0] > 0.15)
        try:
            num_predictions = len(predictions.data[0])
        except IndexError:
            num_predictions = 0

        print('-----------------------------------------------------')
        print(image_id[0])
        print('Raw Prediction', raw_predictions)
        if num_predictions == 0:
            print('No value passed the threshold')
            predictions = [np.argmax(raw_predictions.detach().cpu().numpy())]
            num_predictions = 1
            print("Prediction:", predictions)
            print("Number of predictions", num_predictions)
        else:
            predictions = predictions.data[0].tolist()
            print("Prediction:", predictions)
            print("Number of predictions", num_predictions)

        predicted = ' '.join('%d' % prediction for prediction in predictions)
        print(image_id[0])
        print(predicted)
        pred = dict(Id = image_id[0], Predicted = predicted)
        preds.append(pred)
        
    df = pd.DataFrame(preds)
    df.to_csv('{}.csv'.format(model_name), index = False)

# Agenda

1. Use Focal Loss https://www.kaggle.com/iafoss/pretrained-resnet34-with-rgby-fast-ai
2. Somehow use the Y channel

Trained:
1. InceptionV4 100 epochs
2. SE-ResNext 33 epochs
3. PolyNet 33 epochs

In [10]:
import pretrainedmodels
import torch

model = pretrainedmodels.__dict__['vgg19_bn']()
in_features = model.last_linear.in_features

model.last_linear = torch.nn.Linear(in_features, 28)
model.cuda()
model.load_state_dict(torch.load('vgg19_bn_0_40.pth.tar')['state_dict'])

RuntimeError: Error(s) in loading state_dict for VGG:
	Unexpected key(s) in state_dict: "_features.0.weight", "_features.0.bias", "_features.1.weight", "_features.1.bias", "_features.1.running_mean", "_features.1.running_var", "_features.3.weight", "_features.3.bias", "_features.4.weight", "_features.4.bias", "_features.4.running_mean", "_features.4.running_var", "_features.7.weight", "_features.7.bias", "_features.8.weight", "_features.8.bias", "_features.8.running_mean", "_features.8.running_var", "_features.10.weight", "_features.10.bias", "_features.11.weight", "_features.11.bias", "_features.11.running_mean", "_features.11.running_var", "_features.14.weight", "_features.14.bias", "_features.15.weight", "_features.15.bias", "_features.15.running_mean", "_features.15.running_var", "_features.17.weight", "_features.17.bias", "_features.18.weight", "_features.18.bias", "_features.18.running_mean", "_features.18.running_var", "_features.20.weight", "_features.20.bias", "_features.21.weight", "_features.21.bias", "_features.21.running_mean", "_features.21.running_var", "_features.23.weight", "_features.23.bias", "_features.24.weight", "_features.24.bias", "_features.24.running_mean", "_features.24.running_var", "_features.27.weight", "_features.27.bias", "_features.28.weight", "_features.28.bias", "_features.28.running_mean", "_features.28.running_var", "_features.30.weight", "_features.30.bias", "_features.31.weight", "_features.31.bias", "_features.31.running_mean", "_features.31.running_var", "_features.33.weight", "_features.33.bias", "_features.34.weight", "_features.34.bias", "_features.34.running_mean", "_features.34.running_var", "_features.36.weight", "_features.36.bias", "_features.37.weight", "_features.37.bias", "_features.37.running_mean", "_features.37.running_var", "_features.40.weight", "_features.40.bias", "_features.41.weight", "_features.41.bias", "_features.41.running_mean", "_features.41.running_var", "_features.43.weight", "_features.43.bias", "_features.44.weight", "_features.44.bias", "_features.44.running_mean", "_features.44.running_var", "_features.46.weight", "_features.46.bias", "_features.47.weight", "_features.47.bias", "_features.47.running_mean", "_features.47.running_var", "_features.49.weight", "_features.49.bias", "_features.50.weight", "_features.50.bias", "_features.50.running_mean", "_features.50.running_var". 