In [1]:
#!pip install torchsummary
import torch
from torch import nn
from torch.nn import Module
import datetime
import os
import argparse
import random
import csv
import math
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import numpy as np
import time
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, average_precision_score
import torch.nn.functional as F
#from torchsummary import summary

# **Helper functions**

In [2]:
def dot_norm_exp(a,b):
    """Exponential of the row-wise cosine similarity of ``a`` and ``b``.

    Products and norms are reduced over dim 1, so one value is returned per row.
    """
    numerator = (a * b).sum(dim=1)
    a_len = (a ** 2).sum(dim=1) ** 0.5
    b_len = (b ** 2).sum(dim=1) ** 0.5
    return torch.exp(numerator / (a_len * b_len))

def dot_norm(a,b):
    """Row-wise cosine similarity of ``a`` and ``b`` (reduced over dim 1)."""
    a_len = (a ** 2).sum(dim=1) ** 0.5
    b_len = (b ** 2).sum(dim=1) ** 0.5
    return (a * b).sum(dim=1) / (a_len * b_len)

def dot(a,b):
    """Row-wise dot product: element-wise product of ``a`` and ``b`` summed over dim 1."""
    return (a * b).sum(dim=1)

def norm_euclidian(a,b):
    """Row-wise Euclidean distance between the L2-normalised rows of ``a`` and ``b``."""
    a_unit = a / (torch.sum(a ** 2, dim=1) ** 0.5).unsqueeze(dim=1)
    b_unit = b / (torch.sum(b ** 2, dim=1) ** 0.5).unsqueeze(dim=1)
    squared_gap = torch.sum((a_unit - b_unit) ** 2, dim=1)
    return squared_gap ** 0.5

def get_next_model_folder(prefix, path = ''):
    """Return the first ``<path>/<prefix>_model_run_<k>`` (k = 1, 2, ...) that is not an existing directory.

    The folder is not created here; the chosen path is printed and returned.
    """
    def candidate(run_number):
        return os.path.join(path, f"{prefix}_model_run_{run_number}")

    run_idx = 1
    model_path = candidate(run_idx)
    while os.path.isdir(model_path):
        run_idx += 1
        model_path = candidate(run_idx)

    print(f"STARTING {prefix} RUN {run_idx}! Storing the models at {model_path}")
    return model_path

def get_random_patches(random_patch_loader, num_random_patches):
    """Draw one batch of images and cut one random 64x64 patch out of each image.

    Patch corners are chosen on a 32-pixel grid with offsets 0..6 in each axis,
    i.e. the same positions used by the 7x7 patch grid elsewhere in this
    notebook, so images are assumed to be at least 256x256 (TODO: confirm
    against IMG_SIZE).

    Args:
        random_patch_loader: iterable whose items are dicts with an ``'image'``
            tensor of shape [B, C, H, W] (e.g. a shuffled DataLoader). A fresh
            iterator is created on every call, so only its first batch is used.
        num_random_patches: number of patches wanted.

    Returns:
        dict with
          - ``patches_tensor``: [N, C, 64, 64] with N = min(num_random_patches, B),
            or None when the loader yielded no images;
          - ``is_data_loader_finished``: True when the loader was empty or yielded
            fewer than ``num_random_patches`` images.
    """
    try:
        img_batch = next(iter(random_patch_loader))['image']
    except StopIteration:
        # Empty loader: there is nothing to cut patches from.
        return dict(patches_tensor=None, is_data_loader_finished=True)

    is_data_loader_finished = len(img_batch) < num_random_patches

    patches = []
    # Never index past the end of a short batch.
    for i in range(min(num_random_patches, len(img_batch))):
        x = random.randint(0, 6)
        y = random.randint(0, 6)
        patches.append(img_batch[i:i + 1, :, x * 32:x * 32 + 64, y * 32:y * 32 + 64])

    patches_tensor = torch.cat(patches, dim=0) if patches else None

    return dict(
        patches_tensor=patches_tensor,
        is_data_loader_finished=is_data_loader_finished)

# Tell how many parameters are on the model
def inspect_model(model):
    """Print the shape and element count of every ``state_dict`` entry, then the total.

    The total counts every ``state_dict`` entry, so it includes buffers (e.g.
    BatchNorm running statistics), not only trainable parameters.
    """
    # state_dict() rebuilds the whole mapping on each call; build it once.
    state = model.state_dict()
    param_count = 0
    for param_tensor_str, tensor in state.items():
        numel = tensor.numel()
        print(f"{param_tensor_str} size {tensor.size()} = {numel} params")
        param_count += numel

    print(f"Number of parameters: {param_count}")
    
def get_patch_tensor_from_image_batch(img_batch, grid=7, stride=32, patch_size=64):
    """Cut every image of a batch into an overlapping ``grid`` x ``grid`` set of square patches.

    Args:
        img_batch: tensor [B, C, H, W]. H and W should be at least
            ``(grid - 1) * stride + patch_size`` (256 with the defaults).
        grid: number of patches per row and per column (default 7).
        stride: pixel offset between neighbouring patches (default 32).
        patch_size: side length of each square patch in pixels (default 64).

    Returns:
        Tensor [B * grid * grid, C, patch_size, patch_size]. All patches of image 0
        come first, then image 1, and so on. Within an image the patches are in
        row-major order (row, then column).
    """
    crops = [
        img_batch[:, :, row * stride:row * stride + patch_size, col * stride:col * stride + patch_size]
        for row in range(grid)
        for col in range(grid)
    ]
    # [B, grid*grid, C, P, P] -> [B*grid*grid, C, P, P]. The explicit leading size
    # keeps this valid for an empty batch (B == 0).
    patches = torch.stack(crops, dim=1)
    return patches.reshape(patches.shape[0] * patches.shape[1], *patches.shape[2:])
    
def write_csv_stats(csv_path, stats_dict):
    """Append one row of statistics to a CSV file, writing a header row first if the file is new.

    Single-element tensors are converted to Python numbers. Finite floats are
    truncated (floored) to 3 decimal places. The caller's dict is left unchanged.

    Args:
        csv_path: destination CSV file.
        stats_dict: column name -> value; its key order defines the column order.
    """
    scale = 1000  # keep 3 decimals

    row = []
    for value in stats_dict.values():
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            value = value.item()
        if isinstance(value, float) and math.isfinite(value):
            # Dividing an integer by 1000 gives the float nearest to k/1000, so the
            # CSV shows clean decimals; inf/nan are written unchanged.
            value = math.floor(value * scale) / scale
        row.append(value)

    write_header = not os.path.isfile(csv_path)
    # newline='' is what the csv module expects; it avoids blank lines on Windows.
    with open(csv_path, "a", newline="") as f:
        csv_writer = csv.writer(f)
        if write_header:
            csv_writer.writerow(stats_dict.keys())
        csv_writer.writerow(row)

def compute_pre_recall_f1(target, pred):
    """Binary F1 score of ``pred`` against ``target``.

    precision_recall_fscore_support returns (precision, recall, f1, support);
    only the F1 entry is returned.
    """
    scores = precision_recall_fscore_support(target, pred, average='binary')
    return scores[2]

class EarlyStopper:
    """Interface for early-stopping policies.

    Subclasses implement ``stop`` and set the attributes returned by
    ``get_best_vl_metrics`` (the metrics recorded at the best epoch).
    """

    def stop(self, epoch, val_loss, val_auc=None,  test_loss=None, test_auc=None, test_ap=None,test_f1=None, train_loss=None,score=None,target=None):
        """Return True when training should stop. Must be overridden."""
        raise NotImplementedError("Implement this method!")

    def get_best_vl_metrics(self):
        """Return (train_loss, val_loss, val_auc, test_loss, test_auc, test_ap, test_f1, best_epoch, score, target)."""
        field_names = ("train_loss", "val_loss", "val_auc", "test_loss", "test_auc",
                       "test_ap", "test_f1", "best_epoch", "score", "target")
        return tuple(getattr(self, field) for field in field_names)

class Patience(EarlyStopper):

    '''
    Classic "patience" early stopping.

    The monitored value is ``train_loss`` when ``use_train_loss`` is True and
    ``val_loss`` otherwise. When it matches or beats the best value seen so far,
    every metric passed to ``stop`` is recorded, the counter resets and ``stop``
    returns False. Otherwise the counter grows and ``stop`` returns True once it
    reaches ``patience``.
    '''

    def __init__(self, patience=10, use_train_loss=True):
        self.local_val_optimum = float("inf")
        self.use_train_loss = use_train_loss
        self.patience = patience
        self.best_epoch = -1
        self.counter = -1

        # Metrics captured at the best epoch; None until the first improvement.
        for attr in ("train_loss", "val_loss", "val_auc", "test_loss", "test_auc",
                     "test_ap", "test_f1", "score", "target"):
            setattr(self, attr, None)

    def stop(self, epoch, val_loss, val_auc=None, test_loss=None, test_auc=None, test_ap=None,test_f1=None,train_loss=None,score=None,target=None):
        monitored = train_loss if self.use_train_loss else val_loss

        if not monitored <= self.local_val_optimum:
            # No improvement (also the case for NaN): count towards patience.
            self.counter += 1
            return self.counter >= self.patience

        self.counter = 0
        self.local_val_optimum = monitored
        self.best_epoch = epoch
        self.train_loss = train_loss
        self.val_loss, self.val_auc = val_loss, val_auc
        self.test_loss, self.test_auc, self.test_ap, self.test_f1 = test_loss, test_auc, test_ap, test_f1
        self.score, self.target = score, target
        return False

# **Dataset**

In [3]:
class ImageNetDataset(Dataset):
    """Image-folder dataset: every sub-directory of ``data_path`` is a class, every file in it an image.

    Class directories and file names are sorted. Class indices and sample order
    are therefore reproducible, and a class name gets the same index in train and
    test folders that contain the same set of classes.

    Args:
        data_path: root directory with one sub-directory per class.
        is_train: if True, the sample order is shuffled with ``random_seed``.
        random_seed: seed of the shuffle. A private NumPy RandomState is used, so
            the global NumPy RNG is not reseeded.
        target_transform: accepted for API compatibility; not used.
        num_classes: if given (and > 0), only the first ``num_classes`` classes in
            sorted name order are loaded.
    """

    def __init__(self, data_path, is_train, random_seed = 42, target_transform = None, num_classes = None):
        super(ImageNetDataset, self).__init__()
        self.data_path = data_path
        self.is_classes_limited = num_classes is not None

        class_names = sorted(
            name for name in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, name)))
        if self.is_classes_limited and num_classes > 0:
            class_names = class_names[:num_classes]

        self.classes = [dict(class_idx=idx, class_name=name) for idx, name in enumerate(class_names)]
        # Number of classes actually found (may be fewer than requested).
        self.num_classes = len(self.classes)

        self.image_list = []
        for cls in self.classes:
            class_path = os.path.join(data_path, cls['class_name'])
            for image_name in sorted(os.listdir(class_path)):
                self.image_list.append(dict(
                    cls = cls,
                    image_path = os.path.join(class_path, image_name),
                    image_name = image_name,
                ))

        self.img_idxes = np.arange(0, len(self.image_list))
        if is_train:
            np.random.RandomState(random_seed).shuffle(self.img_idxes)

    def __len__(self):
        return len(self.img_idxes)

    def __getitem__(self, index):
        """Load one image and return dict(image, cls, class_name).

        The image is converted to RGB. If its shorter side exceeds 1.5 * IMG_SIZE[0],
        it is resized so the shorter side equals int(1.5 * IMG_SIZE[0]). If the
        shorter side is smaller than IMG_SIZE[0], it is resized to exactly IMG_SIZE.
        It is then randomly cropped to IMG_SIZE and converted to a [3, H, W] float
        tensor. IMG_SIZE is a global defined elsewhere in the notebook.
        """
        img_info = self.image_list[self.img_idxes[index]]

        # convert('RGB') handles L, P, LA, RGBA, CMYK, ... and always yields 3 channels;
        # the context manager closes the underlying file.
        with Image.open(img_info['image_path']) as raw_img:
            img = raw_img.convert('RGB')

        side = IMG_SIZE[0]
        if min(img.size) > side * 1.5:
            img = transforms.Resize(int(side * 1.5))(img)
        if min(img.size) < side:
            img = transforms.Resize(IMG_SIZE)(img)

        img = transforms.RandomCrop(IMG_SIZE)(img)
        img = transforms.ToTensor()(img)

        return dict(image = img, cls = img_info['cls']['class_idx'], class_name = img_info['cls']['class_name'])

    def get_number_of_classes(self):
        return self.num_classes

    def get_number_of_samples(self):
        return self.__len__()

    def get_class_names(self):
        return [cls['class_name'] for cls in self.classes]

    def get_class_name(self, class_idx):
        return self.classes[class_idx]['class_name']


def get_imagenet_datasets(train_path, test_path, train_split = 0.9, num_classes_train = None, num_classes_test = None, random_seed = None):
    """Build the train/validation/test datasets.

    The training folder is split into train and validation subsets
    (``train_split`` fraction for training); the test folder is used unshuffled.

    Args:
        train_path, test_path: image-folder roots (one sub-directory per class).
        train_split: fraction of the training folder kept for training.
        num_classes_train, num_classes_test: optional class limits per folder.
        random_seed: seeds both the sample shuffle and the train/valid split;
            defaults to the current time, so runs differ unless a seed is given.

    Returns:
        (dataset_train, dataset_valid, dataset_test); the first two are Subsets.
    """
    if random_seed is None:
        random_seed = int(time.time())
    dataset_train = ImageNetDataset(train_path, is_train = True, random_seed=random_seed, num_classes = num_classes_train)
    trainset_size = int(len(dataset_train)*train_split)
    validset_size = len(dataset_train) - trainset_size
    # Seed the split as well, so one random_seed reproduces the same partition.
    split_generator = torch.Generator().manual_seed(random_seed)
    dataset_train, dataset_valid = random_split(dataset_train, [trainset_size, validset_size], generator=split_generator)
    dataset_test = ImageNetDataset(test_path, is_train = False, random_seed=random_seed, num_classes = num_classes_test)

    return dataset_train, dataset_valid, dataset_test
    
def get_random_patch_loader(dataset_train, num_random_patches=None):
    """Shuffled DataLoader whose batch size is the number of random (negative) patches.

    Args:
        dataset_train: dataset to draw images from.
        num_random_patches: batch size. When omitted, the notebook-level
            ``args.num_random_patches`` is used.
    """
    if num_random_patches is None:
        num_random_patches = args.num_random_patches
    return DataLoader(dataset_train, num_random_patches, shuffle=True)

# **Resnet Blocks**

* **Pre-activation residual block (ResNet18-style basic block)**

In [4]:
class ResBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> conv3x3) twice, plus a shortcut.

    With ``downsample`` the first conv and a 1x1 shortcut conv use stride 2,
    halving H and W. Otherwise the shortcut is the identity, which requires
    in_channels == out_channels for the final addition.
    """
    def __init__(self, in_channels, out_channels, downsample):
        super().__init__()
        first_stride = 2 if downsample else 1
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=first_stride, padding=1)
        if downsample:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2))
        else:
            self.shortcut = nn.Sequential()
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1)

    def forward(self, input):
        residual = self.shortcut(input)
        out = self.conv1(torch.relu(self.bn1(input)))
        out = self.conv2(torch.relu(self.bn2(out)))
        return out + residual

# **ResEncoderModel**

* **ResNet18-style encoder (first three stages only; stage 4 is disabled)**

In [5]:
class ResNet(nn.Module):
    """ResNet-style patch encoder with three residual stages (stage 4 is commented out).

    Stem: conv7x7/2 -> maxpool/2 -> BN -> ReLU. Then ``repeat[0]``, ``repeat[1]``
    and ``repeat[2]`` ``resblock`` units with 64, 128 and 256 channels (the first
    unit of stages 2 and 3 downsamples), global average pooling and a linear
    projection to ``outputs`` features.
    """
    def __init__(self, in_channels, resblock, repeat, outputs=1024):
        super().__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        filters = [64, 64, 128, 256]

        def make_stage(first_name, repeat_prefix, in_ch, out_ch, count, downsample):
            # First unit changes width (and optionally resolution); the rest keep it.
            stage = nn.Sequential()
            stage.add_module(first_name, resblock(in_ch, out_ch, downsample=downsample))
            for unit_idx in range(2, count + 1):
                stage.add_module(f'{repeat_prefix}_{unit_idx}', resblock(out_ch, out_ch, downsample=False))
            return stage

        self.layer1 = make_stage('conv2_1', 'conv2', filters[0], filters[1], repeat[0], False)
        self.layer2 = make_stage('conv3_1', 'conv3', filters[1], filters[2], repeat[1], True)
        # Stage-3 repeat units are registered as 'conv2_<k>' (not 'conv4_<k>');
        # renaming them would change the state_dict keys.
        self.layer3 = make_stage('conv4_1', 'conv2', filters[2], filters[3], repeat[2], True)

        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(filters[3], outputs)

    def forward(self, input):
        features = input
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3):
            features = stage(features)
        # Flatten from dim 1 so the batch dimension survives even when B == 1.
        pooled = torch.flatten(self.gap(features), start_dim=1)
        return self.fc(pooled)

# **DENSENET**

In [6]:
class BasicBlock(nn.Module):
    """DenseNet layer: BN -> ReLU -> conv3x3 producing ``out_planes`` maps (plus dropout
    when ``dropRate`` > 0), concatenated onto the input along the channel axis."""
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        new_maps = self.conv1(self.relu(self.bn1(x)))
        if not self.droprate > 0:
            return torch.cat([x, new_maps], 1)
        new_maps = F.dropout(new_maps, p=self.droprate, training=self.training)
        return torch.cat([x, new_maps], 1)

class BottleneckBlock(nn.Module):
    """DenseNet-BC layer.

    BN -> ReLU -> conv1x1 to 4 * out_planes maps, then BN -> ReLU -> conv3x3 to
    out_planes maps, with dropout after each conv when ``dropRate`` > 0. The
    result is concatenated onto the input along the channel axis.
    """
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlock, self).__init__()
        bottleneck_width = 4 * out_planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, bottleneck_width, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck_width)
        self.conv2 = nn.Conv2d(bottleneck_width, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def _maybe_dropout(self, tensor):
        # Dropout is only applied for a strictly positive rate.
        if not self.droprate > 0:
            return tensor
        return F.dropout(tensor, p=self.droprate, inplace=False, training=self.training)

    def forward(self, x):
        hidden = self._maybe_dropout(self.conv1(self.relu(self.bn1(x))))
        hidden = self._maybe_dropout(self.conv2(self.relu(self.bn2(hidden))))
        return torch.cat([x, hidden], 1)

class TransitionBlock(nn.Module):
    """DenseNet transition layer.

    BN -> ReLU -> conv1x1 to ``out_planes`` maps, optional dropout, then 2x2
    average pooling, which halves H and W.
    """
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        projected = self.conv1(self.relu(self.bn1(x)))
        if not self.droprate > 0:
            return F.avg_pool2d(projected, 2)
        dropped = F.dropout(projected, p=self.droprate, inplace=False, training=self.training)
        return F.avg_pool2d(dropped, 2)

class DenseBlock(nn.Module):
    """Stack of ``nb_layers`` dense layers built from ``block``.

    Layer i receives in_planes + i * growth_rate input channels and adds
    ``growth_rate`` new ones.
    """
    def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
        super(DenseBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)

    def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
        return nn.Sequential(*[
            block(in_planes + depth_idx * growth_rate, growth_rate, dropRate)
            for depth_idx in range(nb_layers)
        ])

    def forward(self, x):
        return self.layer(x)

class DenseNet3(nn.Module):
    """DenseNet(-BC) patch encoder.

    Structure: conv3x3 stem, three dense blocks separated by two transition
    layers, then conv3x3 (256 maps) -> avg_pool2d(8) -> global average pool ->
    linear projection.

    Args:
        depth: network depth; each dense block has (depth - 4) / 3 layers, halved
            when ``bottleneck`` is True.
        growth_rate: channels added by each dense layer.
        reduction: channel compression factor of the transition layers.
        bottleneck: use BottleneckBlock (DenseNet-BC) instead of BasicBlock.
        dropRate: dropout probability inside dense and transition layers.
        outputs: size of the output embedding (default 1024).
    """
    def __init__(self, depth, growth_rate=12,
                 reduction=0.5, bottleneck=True, dropRate=0.0, outputs=1024):
        super(DenseNet3, self).__init__()
        in_planes = 2 * growth_rate
        n = (depth - 4) / 3
        if bottleneck:
            n = n / 2
            block = BottleneckBlock
        else:
            block = BasicBlock
        n = int(n)
        # 1st conv before any dense block
        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes+n*growth_rate)
        self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
        in_planes = int(math.floor(in_planes*reduction))
        # 2nd block
        self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes+n*growth_rate)
        self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
        in_planes = int(math.floor(in_planes*reduction))
        # 3rd block
        self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes+n*growth_rate)
        # The head consumes the actual block-3 width; this equals 342 for
        # depth=100, growth_rate=12, bottleneck=True, and differs for other configs.
        self.conv2 = nn.Conv2d(in_channels=in_planes, out_channels=256, kernel_size=3)
        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(256, outputs)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style init scaled by the fan-out of the conv.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.block1(out))
        out = self.trans2(self.block2(out))
        out = self.block3(out)
        out = self.conv2(out)
        # NOTE(review): for 64x64 patches, conv2 outputs 14x14 maps. avg_pool2d with
        # kernel 8 (stride 8) then keeps only the top-left 8x8 window. The
        # following global pool cannot recover the rest. Confirm this is intended.
        out = F.avg_pool2d(out, 8)
        out = self.gap(out)
        out = torch.flatten(out, start_dim=1)
        out = self.fc(out)
        return out

# **Context Prediction Model**

* **CPC context predictor: from the encoded patch grid, predicts the encodings of the patches 1–4 steps away in one direction**

In [7]:
class ContextPredictionModelWithDir(Module):
    """Directional CPC context predictor.

    Holds one linear map W_k per prediction step k = 1..4. Given an encoded
    patch grid, it takes every context cell from which k steps in ``direction``
    stay on the grid and predicts the encoding of the cell k steps away.

    Args:
        in_channels: encoding size D (input and output size of every W_k).
        direction: 'DOWN', 'UP' or 'RIGHT'; any other value is treated as 'LEFT'.
    """

    def __init__(self, in_channels, direction):
        super(ContextPredictionModelWithDir, self).__init__()
        self.in_channels = in_channels
        self.direction = direction
        # One predictor per step distance (1..4).
        self.prediction_weights = nn.ModuleList([nn.Linear(
                        in_features = self.in_channels,
                        out_features = self.in_channels,
                    ) for i in range(4)])

    def forward(self, x):
        """Predict encodings of the cells 1..K steps away (K = number of predictors).

        Args:
            x: encoded patch grid [B, D, H, W] (7x7 in this notebook), D == in_channels.

        Returns:
            (predictions, locations):
              predictions: [K * N, D], where N = B * (number of context cells);
                the k-th block of N rows is W_k applied to every context vector.
              locations: [K * N, 2] long tensor of (row, col) target cells,
                aligned row-for-row with ``predictions``.
            Within each block the rows are ordered by context row, then column,
            then batch element. Every run of B consecutive rows therefore shares
            one target location.
        """
        batch, channels, height, width = x.shape
        num_steps = len(self.prediction_weights)

        # Context band: cells from which all 1..num_steps steps stay on the grid
        # (rows 0-2 / 4-6 or cols 0-2 / 4-6 for a 7x7 grid and 4 steps).
        if self.direction == 'DOWN':
            rows, cols, step = range(0, height - num_steps), range(width), (1, 0)
        elif self.direction == 'UP':
            rows, cols, step = range(num_steps, height), range(width), (-1, 0)
        elif self.direction == 'RIGHT':
            rows, cols, step = range(height), range(0, width - num_steps), (0, 1)
        else:
            rows, cols, step = range(height), range(num_steps, width), (0, -1)

        context = x[:, :, rows.start:rows.stop, cols.start:cols.stop]    # [B, D, R, Q]
        # Reorder to (row, col, batch) before flattening into context vectors.
        z_context = context.permute(2, 3, 0, 1).reshape(-1, channels)    # [R*Q*B, D]

        context_locs = torch.tensor(
            [(r, c) for r in rows for c in cols for _ in range(batch)],
            dtype=torch.long).reshape(-1, 2)

        all_predictions = []
        all_locs = []
        for k, predictor in enumerate(self.prediction_weights, start=1):
            all_predictions.append(predictor(z_context))
            all_locs.append(context_locs + torch.tensor([step[0] * k, step[1] * k]))

        return torch.cat(all_predictions, dim = 0), torch.cat(all_locs, dim = 0)

# **Context predictor training**

In [8]:
def run_validation(args, res_encoder_model, context_predictor_model, random_patch_loader, data_loader_train, data_loader_valid):
    """Mean CPC loss of the encoder + context predictor over the validation set.

    Both models are put in eval mode and stay there; the caller switches back to
    train mode. Everything runs under ``torch.no_grad()``.

    Args:
        args: run configuration; reads ``device`` and ``num_random_patches``.
        res_encoder_model: patch encoder.
        context_predictor_model: directional context predictor.
        random_patch_loader: loader that supplies the random (negative) patches.
        data_loader_train: unused; kept for signature compatibility.
        data_loader_valid: validation loader. It must use batch size 1, because
            the 49 patch encodings are reshaped into a single 7x7 grid.

    Returns:
        The loss averaged over validation batches (as returned by calculate_score_dir).

    Raises:
        RuntimeError: if no full batch of random patches can be drawn.
        ZeroDivisionError: if ``data_loader_valid`` is empty.
    """
    res_encoder_model.eval()
    context_predictor_model.eval()

    loss_total = 0
    random_patches = None
    with torch.no_grad():
        for data in data_loader_valid:
            img = data['image'].to(args.device)
            patches = get_patch_tensor_from_image_batch(img)

            patches_return = get_random_patches(random_patch_loader, args.num_random_patches)
            if patches_return['is_data_loader_finished']:
                # Rebuild from the loader's own dataset (no global dataset needed);
                # the previously drawn random patches are reused for this image.
                random_patch_loader = DataLoader(random_patch_loader.dataset, args.num_random_patches, shuffle=True)
            else:
                random_patches = patches_return['patches_tensor'].to(args.device)
            if random_patches is None:
                raise RuntimeError(
                    "run_validation: could not draw a full batch of random patches "
                    "(is the dataset smaller than num_random_patches?)")

            # [49, D] -> [1, 7, 7, D] -> [1, D, 7, 7]
            enc_patches = res_encoder_model(patches).view(1, 7, 7, -1).permute(0, 3, 1, 2)
            predictions, locations = context_predictor_model(enc_patches)
            enc_random_patches = res_encoder_model(random_patches)

            loss_total += calculate_score_dir(enc_patches, predictions, locations, enc_random_patches)

    return loss_total / len(data_loader_valid)

def _snapshot_state(model):
    """Detached copy of a model's state_dict that later training steps do not change."""
    return {key: tensor.detach().clone() for key, tensor in model.state_dict().items()}


def _cpc_losses(patches_encoded, predictions, locations, enc_random_patches, batch_size, num_random_patches):
    """Per-sample InfoNCE losses for one image batch.

    ``predictions``/``locations`` are consumed in consecutive chunks of
    ``batch_size`` rows. The location of a chunk's first row is taken as the
    target cell for the whole chunk, so each chunk is assumed to share one cell.
    Each prediction is scored against:
      - the encoded patch at that cell (positive logit);
      - the first ``num_random_patches`` encoded random patches (negative logits).
    The loss is -log softmax of the positive logit.

    Returns:
        1-D tensor with ``(len(predictions) // batch_size) * batch_size`` losses.
    """
    negatives = enc_random_patches[:num_random_patches]                 # [R, D]
    losses = []
    for chunk in range(len(predictions) // batch_size):
        start = chunk * batch_size
        p_y = locations[start][0]
        p_x = locations[start][1]
        target = patches_encoded[:, :, p_y, p_x]                       # [B, D] encoded z_{i+k,j}
        pred = predictions[start:start + batch_size]                   # [B, D] predicted W_k c_{i,j}
        positive = torch.sum(pred * target, dim=1, keepdim=True)       # [B, 1]
        negative = pred @ negatives.t()                                # [B, R]
        logits = torch.cat([positive, negative], dim=1)
        losses.append(-torch.log_softmax(logits, dim=1)[:, 0])
    return torch.cat(losses)


def run_context_predictor(args, denseNet_model, context_predictor_model, models_store_path):
    """Train the patch encoder and one directional context predictor with a CPC (InfoNCE) loss.

    Optimizer steps and checkpoints:
      - Gradients are accumulated over sub-batches of ``args.sub_batch_size``
        images. One Adam step is taken once at least ``args.batch_size`` images
        have been processed.
      - After each step the "last_*" weights are saved.
      - The "<direction>_best_*" weights are saved whenever the accumulated batch
        loss improves.

    Validation and early stopping:
      - Every 5 epochs the validation loss is computed.
      - After ``args.patience`` consecutive non-improving validations, training
        stops and the weights from the best validation are written to the
        "<direction>_best_*" files.

    Args:
        args: run configuration. Reads train_image_folder, test_image_folder,
            num_classes_train, num_classes_test, device, sub_batch_size,
            batch_size, num_random_patches, num_epochs and patience.
        denseNet_model: patch encoder, trained in place.
        context_predictor_model: directional context predictor, trained in place.
        models_store_path: existing directory for checkpoints and pred_stats.csv.
    """
    print("RUNNING CONTEXT PREDICTOR " +str(context_predictor_model.direction)+ " TRAINING")

    # Checkpoint files are prefixed with the direction.
    prefix = str(context_predictor_model.direction)
    best_encoder_path = os.path.join(models_store_path, f"{prefix}_best_res_encoder_weights.pt")
    best_context_path = os.path.join(models_store_path, f"{prefix}_best_context_weights.pt")

    dataset_train, dataset_valid, dataset_test = get_imagenet_datasets(args.train_image_folder, args.test_image_folder, num_classes_train = args.num_classes_train, num_classes_test = args.num_classes_test)

    stats_csv_path = os.path.join(models_store_path, "pred_stats.csv")

    random_patch_loader = get_random_patch_loader(dataset_train)
    data_loader_train = DataLoader(dataset_train, args.sub_batch_size, shuffle = True)
    data_loader_valid = DataLoader(dataset_valid, 1, shuffle = True)

    params = list(denseNet_model.parameters()) + list(context_predictor_model.parameters())
    optimizer = torch.optim.Adam(params = params, lr=0.00015)

    trigger = 0
    patience = args.patience
    sub_batches_processed = 0
    batch_loss = 0
    sum_batch_loss = 0
    best_batch_loss = 1e10
    best_valid_loss = 1e10
    best_valid_epoch = None
    # (encoder_state, context_state) copied at the best validation so far.
    best_valid_state = None
    random_patches = None

    for epoch in range(1, args.num_epochs + 1):

        print("RUNNING EPOCH #" + str(epoch))
        denseNet_model.train()
        context_predictor_model.train()

        for batch in data_loader_train:

            img_batch = batch['image'].to(args.device)
            batch_size = len(img_batch)
            patch_batch = get_patch_tensor_from_image_batch(img_batch)

            # Encode the 49 patches of every image: [B*49, D] -> [B, 7, 7, D] -> [B, D, 7, 7]
            patches_encoded = denseNet_model(patch_batch)
            patches_encoded = patches_encoded.view(batch_size, 7, 7, -1).permute(0, 3, 1, 2)

            patches_return = get_random_patches(random_patch_loader, args.num_random_patches)
            if patches_return['is_data_loader_finished']:
                # Keep the previous negatives for this step and start a fresh loader.
                random_patch_loader = get_random_patch_loader(dataset_train)
            else:
                random_patches = patches_return['patches_tensor'].to(args.device)
            if random_patches is None:
                raise RuntimeError(
                    "run_context_predictor: could not draw a full batch of random patches "
                    "(is the dataset smaller than num_random_patches?)")

            enc_random_patches = denseNet_model(random_patches)
            predictions, locations = context_predictor_model(patches_encoded)

            losses = _cpc_losses(patches_encoded, predictions, locations, enc_random_patches,
                                 batch_size, args.num_random_patches)
            loss = torch.mean(losses)
            loss.backward()

            sub_batches_processed += batch_size
            batch_loss += loss.detach().to('cpu')
            sum_batch_loss += torch.sum(losses.detach().to('cpu'))

            if sub_batches_processed >= args.batch_size:

                optimizer.step()
                optimizer.zero_grad()

                print(f"{datetime.datetime.now()} Loss: {batch_loss}")
                print(f"{datetime.datetime.now()} SUM Loss: {sum_batch_loss}")

                torch.save(denseNet_model.state_dict(), os.path.join(models_store_path, "last_res_encoder_weights.pt"))
                torch.save(context_predictor_model.state_dict(), os.path.join(models_store_path, "last_context_predictor_weights.pt"))

                if batch_loss < best_batch_loss:
                    best_batch_loss = batch_loss
                    torch.save(denseNet_model.state_dict(), best_encoder_path)
                    torch.save(context_predictor_model.state_dict(), best_context_path)

                stats = dict(
                    batch_loss = batch_loss,
                    sum_batch_loss = sum_batch_loss
                )
                write_csv_stats(stats_csv_path, stats)

                sub_batches_processed = 0
                batch_loss = 0
                sum_batch_loss = 0

        # Early stopping on the validation loss, evaluated every 5 epochs.
        if epoch % 5 == 0:
            valid_loss = run_validation(args, denseNet_model, context_predictor_model, random_patch_loader, data_loader_train, data_loader_valid)
            print('Validation Loss:' +str(valid_loss))

            if valid_loss > best_valid_loss:
                trigger += 1
                if trigger >= patience:
                    print('Early Stopping! Find best epoch: ' + str(best_valid_epoch))
                    if best_valid_state is not None:
                        # Restore the validated-best weights; the live models have moved on since.
                        encoder_state, context_state = best_valid_state
                        torch.save(encoder_state, best_encoder_path)
                        torch.save(context_state, best_context_path)
                    return
            else:
                trigger = 0
                best_valid_loss = valid_loss
                best_valid_epoch = epoch
                best_valid_state = (_snapshot_state(denseNet_model), _snapshot_state(context_predictor_model))

# **Anomaly detection evaluation**

In [9]:
def calculate_score_dir(enc_patches, predictions, locations, enc_random_patches, num_random_patches=None):
    """Contrastive (InfoNCE-style) loss of one direction model, used as an anomaly score.

    For each prediction ``predictions[b]`` the positive target is the encoded patch at
    grid position ``locations[b]`` of ``enc_patches``; the negatives are the first
    ``num_random_patches`` rows of ``enc_random_patches``. Each prediction contributes
    ``-log_softmax`` of the positive dot product taken over [positive, negatives...],
    and the returned score is the mean of these terms, so a higher score means the
    context predictions match the image's actual patches less well.

    Args:
        enc_patches: encoded patch grid indexed as ``enc_patches[:, :, y, x]``
            (``run_anomaly_evaluation`` builds it with shape (1, C, 7, 7)).
        predictions: sequence (list or tensor) of predicted encodings, aligned with
            ``locations``.
        locations: sequence of ``(y, x)`` grid positions, one per prediction.
        enc_random_patches: encodings of random negative patches, one per row.
        num_random_patches: number of negatives to use. ``None`` (default) falls back
            to the notebook-global ``args.num_random_patches`` for backward
            compatibility; pass it explicitly to avoid relying on global state.

    Returns:
        0-dim tensor: the mean loss over all predictions.

    Raises:
        ValueError: if ``predictions`` is empty.
    """
    if len(predictions) == 0:
        raise ValueError("calculate_score_dir needs at least one prediction")
    if num_random_patches is None:
        # Backward-compatible fallback on the notebook-level argparse namespace.
        num_random_patches = args.num_random_patches

    losses = []
    for b, pred in enumerate(predictions):
        p_y, p_x = locations[b][0], locations[b][1]
        target = enc_patches[:, :, p_y, p_x]

        # Dot-product similarities (same computation as the dot() helper);
        # index 0 is the positive pair, the rest are the negatives.
        dot_terms = [torch.unsqueeze(torch.sum(pred * target, dim=1), dim=0)]
        for k in range(num_random_patches):
            negative = enc_random_patches[k:k + 1]
            dot_terms.append(torch.unsqueeze(torch.sum(pred * negative, dim=1), dim=0))

        # Softmax across the candidates (dim 0); the loss is -log p(positive).
        log_probs = torch.log_softmax(torch.cat(dot_terms, dim=0), dim=0)
        losses.append(-log_probs[0])

    return torch.mean(torch.cat(losses))

def run_anomaly_evaluation(args, res_encoder_model_list, context_model_list, models_store_path, output_csv_name='Grid.csv'):
    """Score every test image with the trained CPC models and report detection metrics.

    Each test image is split into patches and encoded by every encoder in
    ``res_encoder_model_list``. The encoding is fed to the matching context model in
    ``context_model_list``. ``calculate_score_dir`` turns each direction's predictions
    into a contrastive loss, using random patches from the training set as negatives.
    The image score is the mean of these losses over the directions.

    Images are labelled 0 when their class name is ``'good'`` and 1 otherwise. The
    anomaly threshold is the score percentile equal to the fraction of normal images.

    Per-image scores and predictions are printed and written to
    ``<models_store_path>/<output_csv_name>``. The AUC, F1 and average precision are
    printed and written there as well.

    Args:
        args: parsed CLI namespace. Reads the train/test folders, class counts,
            ``num_random_patches`` and ``device``.
        res_encoder_model_list: patch encoders, one per direction.
        context_model_list: context predictors aligned with ``res_encoder_model_list``.
        models_store_path: existing directory where the CSV report is written.
        output_csv_name: file name of the CSV report. The default keeps the historical
            ``'Grid.csv'``.

    Raises:
        ValueError: if the test set yields no images.
        RuntimeError: if no random patches can be drawn even after restarting the
            random-patch loader.
    """
    print("RUNNING ANOMALY DETECTION")

    dataset_train, _, dataset_test = get_imagenet_datasets(
        args.train_image_folder, args.test_image_folder, train_split=1,
        num_classes_train=args.num_classes_train, num_classes_test=args.num_classes_test)
    data_loader_test = DataLoader(dataset_test, 1, shuffle=False)
    NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()
    print(NUM_TEST_SAMPLES)

    random_patch_loader = get_random_patch_loader(dataset_train)

    for encoder, context_model in zip(res_encoder_model_list, context_model_list):
        encoder.eval()
        context_model.eval()

    score_all = []
    label_all = []
    # Only the class name of each test image is needed for the report. Keeping the
    # whole batch dict would hold every test image in memory.
    class_names = []

    with torch.no_grad():
        for i, data in enumerate(data_loader_test):
            class_name = data['class_name'][0]
            class_names.append(class_name)

            img = data['image'].to(args.device)
            patches = get_patch_tensor_from_image_batch(img)

            # Negatives: random patches drawn from the training set. If the loader
            # reports it is finished, restart it and draw again. Otherwise
            # random_patches would be unbound (first image) or stale (reused from the
            # previous image).
            patches_return = get_random_patches(random_patch_loader, args.num_random_patches)
            if patches_return['is_data_loader_finished']:
                random_patch_loader = get_random_patch_loader(dataset_train)
                patches_return = get_random_patches(random_patch_loader, args.num_random_patches)
                if patches_return['is_data_loader_finished']:
                    raise RuntimeError("Could not draw random patches from the training set")
            random_patches = patches_return['patches_tensor'].to(args.device)

            print('Image #' + str(i))
            image_scores = []
            for j, (encoder, context_model) in enumerate(zip(res_encoder_model_list, context_model_list)):
                # Encoded patches reshaped to a (1, C, 7, 7) grid, as the context model expects.
                enc_patches = encoder(patches).view(1, 7, 7, -1).permute(0, 3, 1, 2)
                enc_random_patches = encoder(random_patches)
                predictions, locations = context_model(enc_patches)

                score = calculate_score_dir(enc_patches, predictions, locations, enc_random_patches)
                print('Score ' + str(j) + ': ' + str(score.item()))
                image_scores.append(score)

            label_all.append(0 if class_name == 'good' else 1)
            # Image score: mean loss over the direction models.
            score_all.append(sum(image_scores) / len(image_scores))

    if not score_all:
        raise ValueError("The test set is empty; nothing to evaluate")

    # Each score is a 0-dim tensor. Flatten to a 1-D numpy array, keeping its dtype.
    score_all = np.concatenate(np.vstack([s.cpu().numpy() for s in score_all]))

    # Threshold at the percentile matching the share of normal images.
    normal_ratio = sum(1 for a in label_all if a == 0) / len(label_all)
    threshold = np.percentile(score_all, 100 * normal_ratio)
    predictions = np.zeros(len(score_all))
    predictions[score_all > threshold] = 1

    csv_path = os.path.join(models_store_path, output_csv_name)
    # newline='' is required by the csv module to avoid spurious blank rows on Windows.
    with open(csv_path, mode='w', newline='') as csv_file:
        fieldnames = ['Class_name', 'Score', 'Anomaly', 'AUC', 'F1', 'Average precision']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        for i, class_name in enumerate(class_names):
            print("IMAGE #" + str(i))
            print("Class name: " + str(class_name))
            print("Average score: " + str(score_all[i]))
            print("Predictions: " + str(predictions[i]))

            writer.writerow({
                'Class_name': class_name,
                'Score': score_all[i],
                'Anomaly': predictions[i]})

        auc = roc_auc_score(label_all, score_all)
        f1 = compute_pre_recall_f1(label_all, predictions)
        ap = average_precision_score(label_all, score_all)
        print('AUC: ' + str(auc) + '\nF1: ' + str(f1) + '\nAverage Precision: ' + str(ap))
        writer.writerow({
            'AUC': auc,
            'F1': f1,
            'Average precision': ap})

# **MAIN**

* **Contrastive Predictive Coding MAIN**
> Trains the CPC model, composed of an encoder and an autoregressive context-prediction model (one pair per direction).

* **Anomaly Detection MAIN**
> Detects anomalies in images using the trained Contrastive Predictive Coding models.

In [10]:
IMG_SIZE = (256,256)
parser = argparse.ArgumentParser(description='Contrastive predictive coding params')

# 'train_encoder_context_prediction': train one encoder + context predictor per direction.
# 'anomaly_evaluation': load the trained weights and score the test set.
parser.add_argument('-mode', default='anomaly_evaluation', type=str,
                    choices=['train_encoder_context_prediction', 'anomaly_evaluation'])
# Alternative training folders: chest-xray-pneumonia/chest_xray/train,
# mvtec-ad/bottle/train (num_classes_train 1).
parser.add_argument('-train_image_folder', default='../input/mvtec-ad/grid/train', type=str)
parser.add_argument('-num_classes_train', default=1, type=int)
# Alternative test folders: chest-xray-pneumonia/chest_xray (num_classes_test 2),
# mvtec-ad/bottle/test (num_classes_test 4).
# NOTE(review): the default train folder is mvtec-ad/grid while the default test folder is
# mvtec-ad/bottle (and the evaluation report is named Grid.csv); confirm this pairing is intended.
parser.add_argument('-test_image_folder', default='../input/mvtec-ad/bottle/test', type=str)
parser.add_argument('-num_classes_test', default=4, type=int)
parser.add_argument('-batch_size', default=16, type=int)
parser.add_argument('-sub_batch_size', default=2, type=int)
parser.add_argument('-num_random_patches', default=15, type=int)
parser.add_argument('-num_epochs', default=30, type=int)
parser.add_argument('-patience', default=10, type=int)

parser.add_argument('-layers', default=100, type=int,
                    help='total number of layers (default: 100)')
parser.add_argument('-growth', default=12, type=int,
                    help='number of new channels per layer (default: 12)')
parser.add_argument('-reduce', default=0.5, type=float,
                    help='compression rate in transition stage (default: 0.5)')
parser.add_argument('-no-bottleneck', dest='bottleneck', action='store_false',
                    help='To not use bottleneck block')
parser.add_argument('-droprate', default=0, type=float,
                    help='dropout probability (default: 0.0)')
# cpu or cuda
parser.add_argument('-device', default='cuda', type=str)
# Folder holding <DIRECTION>_best_res_encoder_weights.pt / <DIRECTION>_best_context_weights.pt
parser.add_argument('-weights_folder', default='/kaggle/input/outputtrain', type=str,
                    help='folder with the trained per-direction weights used by anomaly_evaluation')

# parse_known_args ignores the extra arguments Jupyter passes to the kernel.
args, args_other = parser.parse_known_args()

print(f"Running CPC with args {args}")

Z_DIMENSIONS = 1024
DIRECTIONS = ['DOWN', 'UP', 'RIGHT', 'LEFT']

stored_models_root_path = "trained_models"
os.makedirs(stored_models_root_path, exist_ok=True)
stored_eval_root_path = "evaluation"
os.makedirs(stored_eval_root_path, exist_ok=True)

if args.mode == 'train_encoder_context_prediction':
    # One encoder + context predictor is trained per prediction direction.
    for direc in DIRECTIONS:
        denseNet_model = DenseNet3(args.layers, args.growth, reduction=args.reduce,
                                   bottleneck=args.bottleneck, dropRate=args.droprate).to(args.device)
        context_predictor_model = ContextPredictionModelWithDir(in_channels=Z_DIMENSIONS, direction=direc).to(args.device)

        # Each run gets a fresh, numbered output folder.
        model_store_folder = get_next_model_folder(direc, stored_models_root_path)
        os.mkdir(model_store_folder)
        run_context_predictor(args, denseNet_model, context_predictor_model, model_store_folder)

elif args.mode == 'anomaly_evaluation':
    denseNet_model_list = []
    context_model_list = []

    for direc in DIRECTIONS:
        denseNet_model = DenseNet3(args.layers, args.growth, reduction=args.reduce,
                                   bottleneck=args.bottleneck, dropRate=args.droprate).to(args.device)
        context_predictor_model = ContextPredictionModelWithDir(in_channels=Z_DIMENSIONS, direction=direc).to(args.device)

        res_encoder_weights_path = os.path.join(args.weights_folder, f'{direc}_best_res_encoder_weights.pt')
        context_weights_path = os.path.join(args.weights_folder, f'{direc}_best_context_weights.pt')

        print(f"Loading res encoder {direc} weights from {res_encoder_weights_path}")
        print(f"Loading context {direc} weights from {context_weights_path}")

        # map_location lets weights saved on a GPU load when running with -device cpu.
        denseNet_model.load_state_dict(torch.load(res_encoder_weights_path, map_location=args.device))
        context_predictor_model.load_state_dict(torch.load(context_weights_path, map_location=args.device))

        denseNet_model_list.append(denseNet_model)
        context_model_list.append(context_predictor_model)

    run_anomaly_evaluation(args, denseNet_model_list, context_model_list, stored_eval_root_path)

Running CPC with args Namespace(batch_size=16, bottleneck=True, device='cuda', droprate=0, growth=12, layers=100, mode='anomaly_evaluation', num_classes_test=4, num_classes_train=1, num_epochs=30, num_random_patches=15, patience=10, reduce=0.5, sub_batch_size=2, test_image_folder='../input/mvtec-ad/bottle/test', train_image_folder='../input/mvtec-ad/grid/train')
Loading res encoder DOWN weights from /kaggle/input/outputtrain/DOWN_best_res_encoder_weights.pt
Loading context DOWN weights from /kaggle/input/outputtrain/DOWN_best_context_weights.pt
Loading res encoder UP weights from /kaggle/input/outputtrain/UP_best_res_encoder_weights.pt
Loading context UP weights from /kaggle/input/outputtrain/UP_best_context_weights.pt
Loading res encoder RIGHT weights from /kaggle/input/outputtrain/RIGHT_best_res_encoder_weights.pt
Loading context RIGHT weights from /kaggle/input/outputtrain/RIGHT_best_context_weights.pt
Loading res encoder LEFT weights from /kaggle/input/outputtrain/LEFT_best_res_enc