In [27]:
#!pip install torchsummary
import torch
from torch import nn
from torch.nn import Module
import datetime
import os
import argparse
import random
import csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import time
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, average_precision_score
import torch.nn.functional as F
#from torchsummary import summary

# **Helper functions**

In [28]:
def dot_norm_exp(a,b):
    """Return the exponential of the cosine similarity of ``a`` and ``b``.

    The dot product over dim 1 is divided by the product of the two
    Euclidean norms over dim 1, then passed through ``torch.exp``.
    No epsilon is added to the denominator, so an all-zero row gives NaN.
    """
    inner = (a * b).sum(dim=1)
    norm_a = a.pow(2).sum(dim=1).pow(0.5)
    norm_b = b.pow(2).sum(dim=1).pow(0.5)
    return torch.exp(inner / (norm_a * norm_b))

def dot_norm(a,b):
    """Return the cosine similarity of ``a`` and ``b`` computed along dim 1.

    No epsilon is added to the denominator, so an all-zero row gives NaN.
    """
    inner = (a * b).sum(dim=1)
    norm_a = a.pow(2).sum(dim=1).pow(0.5)
    norm_b = b.pow(2).sum(dim=1).pow(0.5)
    return inner / (norm_a * norm_b)

def dot(a,b):
    """Return the dot product of ``a`` and ``b`` taken along dim 1."""
    return (a * b).sum(dim=1)

def norm_euclidian(a,b):
    """Return the Euclidean distance between the L2-normalised inputs.

    Each input is divided by its own Euclidean norm over dim 1 (kept as a
    singleton dim for broadcasting). The norm of the difference is then
    taken over dim 1.
    """
    unit_a = a / a.pow(2).sum(dim=1).pow(0.5).unsqueeze(dim=1)
    unit_b = b / b.pow(2).sum(dim=1).pow(0.5).unsqueeze(dim=1)
    return (unit_a - unit_b).pow(2).sum(dim=1).pow(0.5)

def get_next_model_folder(prefix, path = ''):
    """Return the first ``<prefix>_model_run_<n>`` path under ``path`` that is not a directory.

    ``n`` starts at 1 and is increased until no directory with that name
    exists. The directory itself is not created here; the chosen path is
    announced with a print and returned.
    """
    run_idx = 1
    while True:
        model_path = os.path.join(path, f"{prefix}_model_run_{run_idx}")
        if not os.path.isdir(model_path):
            break
        run_idx += 1

    print(f"STARTING {prefix} RUN {run_idx}! Storing the models at {model_path}")

    return model_path

def get_random_patches(random_patch_loader, num_random_patches):
    """Cut one random 64x64 patch out of each image of a freshly drawn batch.

    A new iterator is created over ``random_patch_loader`` on every call and
    its first batch is used; the batch must be a mapping with an ``'image'``
    entry indexable as ``[image, channel, row, col]``. For each of the first
    ``num_random_patches`` images one crop is taken whose top-left corner is
    a random multiple of 32 in ``[0, 192]`` on both spatial axes.
    Assumes the images are at least 256 pixels on both spatial axes; smaller
    images produce truncated crops - TODO confirm against the dataset.

    Returns:
        dict with
        ``patches_tensor``: the crops concatenated along dim 0. It holds
            fewer than ``num_random_patches`` crops when the batch is short,
            and is an empty tensor when the loader yields no batch at all.
        ``is_data_loader_finished``: True when the loader yielded nothing or
            the batch was shorter than ``num_random_patches``.
    """
    try:
        img_batch = next(iter(random_patch_loader))['image']
    except StopIteration:
        # The loader is empty: there is no batch to crop from. Report that
        # through the flag instead of failing later on an unbound variable.
        return dict(
            patches_tensor = torch.empty(0),
            is_data_loader_finished = True)

    available = len(img_batch)
    is_data_loader_finished = available < num_random_patches

    patches = []
    # Only visit images that exist; indexing past the end of the batch would
    # just produce empty slices.
    for i in range(min(available, num_random_patches)):
        row_step = random.randint(0,6)
        col_step = random.randint(0,6)
        top, left = row_step * 32, col_step * 32
        patches.append(img_batch[i:i+1, :, top:top+64, left:left+64])

    if patches:
        patches_tensor = torch.cat(patches, dim=0)
    else:
        # Nothing to crop: keep the channel/patch layout with zero entries.
        patches_tensor = img_batch[0:0, :, 0:64, 0:64]

    return dict(
        patches_tensor = patches_tensor,
        is_data_loader_finished = is_data_loader_finished)

# Report the size of every state_dict entry and the total element count
def inspect_model(model):
    """Print one line per ``model.state_dict()`` entry, then the summed element count.

    The total is taken over every state_dict entry, so for a
    ``torch.nn.Module`` it also counts persistent buffers, not only
    trainable parameters. Nothing is returned.
    """
    entries = model.state_dict()
    total = 0
    for name, tensor in entries.items():
        count = tensor.numel()
        print(f"{name} size {tensor.size()} = {count} params")
        total += count

    print(f"Number of parameters: {total}")
    
# img_batch is expected to hold two images (original author's note)
def get_patch_tensor_from_image_batch(img_batch):
    """Tile an image batch into 25 coarse windows of 49 overlapping patches each.

    ``img_batch`` is indexed as ``[image, channel, row, col]``.
    Stage 1 cuts a 5x5 grid of 256x256 windows with stride 128 (rows outer,
    columns inner). Stage 2 cuts every window into a 7x7 grid of 64x64
    patches with stride 32 (again rows outer, columns inner).
    The windows only have full size when each image is at least 768 pixels
    on both spatial axes - presumably that is the configured image size,
    verify against the caller.

    Returns:
        list of 25 tensors, one per coarse window. Each tensor concatenates,
        image after image, the 49 patches of that window along dim 0, so its
        first dimension is ``49 * number_of_images``.
    """
    # Stage 1: 25 windows of 256x256, stride 128.
    coarse_windows = [
        img_batch[:, :, top:top + 256, left:left + 256]
        for top in range(0, 5 * 128, 128)
        for left in range(0, 5 * 128, 128)
    ]

    patch_batch = []
    for window in coarse_windows:
        # Stage 2: 49 patches of 64x64, stride 32, inside this window.
        tiles = [
            window[:, :, top:top + 64, left:left + 64]
            for top in range(0, 7 * 32, 32)
            for left in range(0, 7 * 32, 32)
        ]
        # (images, 49, channels, 64, 64): patches of one image side by side.
        per_image = torch.stack(tiles, dim=1)
        # Lay the images out one after another along dim 0.
        patch_batch.append(torch.cat(list(per_image), dim=0))

    return patch_batch
    
def write_csv_stats(csv_path, stats_dict):
    """Append one row of statistics to a CSV file, creating it with a header if needed.

    Args:
        csv_path: path of the CSV file. When it does not exist yet, the keys
            of ``stats_dict`` are written first as the header row.
        stats_dict: mapping of column name to value. ``float`` values are
            truncated (floored) to three decimals in the written row; other
            values are written unchanged. The mapping itself is not modified.

    Notes:
        Non-finite floats (inf / NaN) are written as they are.
        The files are opened with ``newline=""`` as the ``csv`` module
        requires, so no blank lines appear between rows on platforms that
        translate line endings.
    """
    import math  # local import: only needed by this helper

    precision = 0.001
    scale = 1000.0  # reciprocal of ``precision``

    def _truncate(value):
        # Floor to a whole number of ``precision`` steps, then divide (not
        # multiply) back: dividing an integer count by 1000 gives the closest
        # float to the 3-decimal value, multiplying by 0.001 may not.
        if not math.isfinite(value):
            return value
        return ((value / precision) // 1.0) / scale

    # Build the row separately so the caller's dict keeps its exact values.
    row = {
        key: _truncate(value) if isinstance(value, float) else value
        for key, value in stats_dict.items()
    }

    needs_header = not os.path.isfile(csv_path)
    with open(csv_path, "a", newline="") as f:
        csv_writer = csv.writer(f)
        if needs_header:
            csv_writer.writerow(row.keys())
        csv_writer.writerow(row.values())

def compute_pre_recall_f1(target, pred):
    """Return the binary F1 score of ``pred`` against ``target``.

    Precision, recall and support are produced by the same
    ``precision_recall_fscore_support`` call but are discarded.
    """
    scores = precision_recall_fscore_support(target, pred, average='binary')
    return scores[2]

class EarlyStopper:
    """Interface for early-stopping policies.

    Subclasses implement ``stop`` and are expected to set the attributes that
    ``get_best_vl_metrics`` reads; this base class does not define them.
    """

    def stop(self, epoch, val_loss, val_auc=None,  test_loss=None, test_auc=None, test_ap=None,test_f1=None, train_loss=None,score=None,target=None):
        """Decide whether training should stop; must be overridden."""
        raise NotImplementedError("Implement this method!")

    def get_best_vl_metrics(self):
        """Return the stored metrics as a 10-tuple.

        Order: train_loss, val_loss, val_auc, test_loss, test_auc, test_ap,
        test_f1, best_epoch, score, target.
        """
        best = (
            self.train_loss,
            self.val_loss,
            self.val_auc,
            self.test_loss,
            self.test_auc,
            self.test_ap,
            self.test_f1,
            self.best_epoch,
            self.score,
            self.target,
        )
        return best

class Patience(EarlyStopper):

    '''
    "Patience" early stopping: stop once the monitored loss has failed to
    improve for ``patience`` consecutive calls to ``stop``.

    The monitored loss is ``train_loss`` when ``use_train_loss`` is true and
    ``val_loss`` otherwise. A value equal to the best one seen so far counts
    as an improvement.
    '''

    def __init__(self, patience=10, use_train_loss=True):
        self.patience = patience
        self.use_train_loss = use_train_loss

        # Best monitored loss so far, and the bookkeeping around it.
        self.local_val_optimum = float("inf")
        self.best_epoch = -1
        self.counter = -1

        # Metrics recorded at the best epoch.
        self.train_loss = None
        self.val_loss = None
        self.val_auc = None
        self.test_loss = None
        self.test_auc = None
        self.test_ap = None
        self.test_f1 = None
        self.score = None
        self.target = None

    def stop(self, epoch, val_loss, val_auc=None, test_loss=None, test_auc=None, test_ap=None,test_f1=None,train_loss=None,score=None,target=None):
        '''
        Record the metrics on improvement and report whether to stop.

        Returns False when the monitored loss is at least as good as the best
        one so far (all passed metrics are stored and the counter is reset).
        Otherwise the counter is increased and True is returned once it
        reaches ``patience``.
        '''
        monitored = train_loss if self.use_train_loss else val_loss

        if not (monitored <= self.local_val_optimum):
            self.counter += 1
            return self.counter >= self.patience

        self.counter = 0
        self.local_val_optimum = monitored
        self.best_epoch = epoch
        self.train_loss = train_loss
        self.val_loss = val_loss
        self.val_auc = val_auc
        self.test_loss = test_loss
        self.test_auc = test_auc
        self.test_ap = test_ap
        self.test_f1 = test_f1
        self.score = score
        self.target = target
        return False

# **Dataset**

In [29]:
class ImageNetDataset(Dataset):
    """Folder-per-class image dataset returning fixed-size random crops.

    Every sub-directory of ``data_path`` is one class and every entry inside
    it is treated as one image. Directory listings are sorted, so class
    indices and the seeded split do not depend on the file-system order and
    two folders with the same class names get the same indices.

    Args:
        data_path: root folder with one sub-directory per class.
        is_train: when true, the image indices are shuffled and the first
            ``train_split`` fraction is kept; otherwise the unshuffled
            indices from fraction ``test_split`` onwards are kept.
        train_split: fraction kept when ``is_train`` is true.
        test_split: start fraction of the kept slice when ``is_train`` is
            false. NOTE(review): with the default of 1 this slice is empty -
            confirm that callers pass a smaller value when they need data.
        random_seed: seed handed to ``np.random.seed`` (this seeds NumPy's
            global generator).
        target_transform: accepted but not used by this class.
        num_classes: when given, only the first ``num_classes`` class folders
            (in sorted order) are used.

    ``__getitem__`` relies on a module-level ``IMG_SIZE`` that is not defined
    in this class; it is presumably a (height, width) pair - TODO confirm.
    """

    def __init__(self, data_path, is_train, train_split = 0.9, test_split = 1, random_seed = 42, target_transform = None, num_classes = None):
        super(ImageNetDataset, self).__init__()
        self.data_path = data_path

        self.is_classes_limited = num_classes is not None
        if self.is_classes_limited:
            self.num_classes = num_classes

        # One class per sub-directory; sorted because os.listdir returns the
        # entries in arbitrary order.
        self.classes = []
        for class_name in sorted(os.listdir(data_path)):
            if not os.path.isdir(os.path.join(data_path, class_name)):
                continue
            self.classes.append(
               dict(
                   class_idx = len(self.classes),
                   class_name = class_name,
               ))

            if self.is_classes_limited and len(self.classes) == self.num_classes:
                break

        if not self.is_classes_limited:
            self.num_classes = len(self.classes)

        self.image_list = []
        for cls in self.classes:
            class_path = os.path.join(data_path, cls['class_name'])
            for image_name in sorted(os.listdir(class_path)):
                self.image_list.append(dict(
                    cls = cls,
                    image_path = os.path.join(class_path, image_name),
                    image_name = image_name,
                ))

        self.img_idxes = np.arange(0,len(self.image_list))

        np.random.seed(random_seed)

        if is_train:
            np.random.shuffle(self.img_idxes)
            last_train_sample = int(len(self.img_idxes) * train_split)
            self.img_idxes = self.img_idxes[:last_train_sample]
        else:
            last_train_sample = int(len(self.img_idxes) * test_split)
            self.img_idxes = self.img_idxes[last_train_sample:]

    def __len__(self):
        return len(self.img_idxes)

    def __getitem__(self, index):
        """Load one image and return a random ``IMG_SIZE`` crop with its label.

        Returns a dict with ``image`` (3-channel tensor produced by
        ``transforms.ToTensor``), ``cls`` (class index) and ``class_name``.
        """
        img_info = self.image_list[self.img_idxes[index]]

        # Image.open is lazy and keeps the file open; convert() loads the
        # pixels, so the file can be closed right after. Converting to RGB
        # gives three channels for every source mode (grayscale, palette,
        # RGBA, CMYK, ...).
        with Image.open(img_info['image_path']) as img_file:
            img = img_file.convert('RGB')

        # Shrink very large images so the shorter side is 1.5x the crop size.
        width, height = img.size
        if min(width, height) > IMG_SIZE[0] * 1.5:
            img = transforms.Resize(int(IMG_SIZE[0] * 1.5))(img)

        # Images whose shorter side is below the crop size are resized to
        # exactly IMG_SIZE (aspect ratio is not preserved in this case).
        width, height = img.size
        if min(width, height) < IMG_SIZE[0]:
            img = transforms.Resize(IMG_SIZE)(img)

        img = transforms.RandomCrop(IMG_SIZE)(img)
        img = transforms.ToTensor()(img)

        return dict(image = img, cls = img_info['cls']['class_idx'], class_name = img_info['cls']['class_name'])

    def get_number_of_classes(self):
        """Return ``num_classes`` (the requested limit, or the number of class folders found)."""
        return self.num_classes

    def get_number_of_samples(self):
        """Return the number of images kept after the split."""
        return self.__len__()

    def get_class_names(self):
        """Return the class folder names in class-index order."""
        return [cls['class_name'] for cls in self.classes]

    def get_class_name(self, class_idx):
        """Return the folder name of class ``class_idx``."""
        return self.classes[class_idx]['class_name']


def get_imagenet_datasets(train_path, test_path, train_split = 0.9, test_split = 1, num_classes_train = None, num_classes_test = None, random_seed = None):
    """Build the training and the test ``ImageNetDataset`` with a shared seed.

    When ``random_seed`` is not given, the current time in whole seconds is
    used. ``train_split`` is forwarded to both datasets, ``test_split`` only
    to the test dataset.

    Returns:
        (dataset_train, dataset_test)
    """
    seed = int(time.time()) if random_seed is None else random_seed

    dataset_train = ImageNetDataset(
        train_path,
        is_train = True,
        random_seed = seed,
        num_classes = num_classes_train,
        train_split = train_split,
    )
    dataset_test = ImageNetDataset(
        test_path,
        is_train = False,
        random_seed = seed,
        num_classes = num_classes_test,
        train_split = train_split,
        test_split = test_split,
    )

    return dataset_train, dataset_test
    
def get_random_patch_loader(dataset_train, num_random_patches = None):
    """Return a shuffling ``DataLoader`` over ``dataset_train``.

    Args:
        dataset_train: dataset to draw from.
        num_random_patches: batch size of the loader. When omitted, the value
            is read from the notebook-level ``args.num_random_patches``, so
            existing one-argument calls keep working; passing it explicitly
            removes the dependency on that global.
    """
    if num_random_patches is None:
        num_random_patches = args.num_random_patches
    return DataLoader(dataset_train, num_random_patches, shuffle=True)

# **Resnet18-v2 Model**

In [30]:
class ResBlock(nn.Module):
    """Residual block with batch norm and ReLU applied before each 3x3 convolution.

    With ``downsample`` the first convolution uses stride 2 and the shortcut
    is a 1x1 stride-2 convolution; otherwise the shortcut is the identity
    (an empty ``nn.Sequential``).
    """

    def __init__(self, in_channels, out_channels, downsample):
        super().__init__()
        first_stride = 2 if downsample else 1

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=first_stride, padding=1)

        if downsample:
            projection = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2)]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, input):
        """Return ``conv2(relu(bn2(conv1(relu(bn1(input)))))) + shortcut(input)``."""
        residual = self.shortcut(input)
        out = self.conv1(F.relu(self.bn1(input)))
        out = self.conv2(F.relu(self.bn2(out)))
        return out + residual

In [31]:
class ResNet(nn.Module):
    """Three-stage residual network ending in global average pooling and a linear layer.

    Args:
        in_channels: channels of the input image.
        resblock: callable ``resblock(in_ch, out_ch, downsample=...)`` that
            builds one block.
        repeat: number of blocks per stage; entries 0, 1 and 2 are used.
        outputs: size of the final linear layer's output.

    Stage widths are 64, 128 and 256; stages 2 and 3 start with a
    downsampling block.
    """

    def __init__(self, in_channels, resblock, repeat, outputs=1024):
        super().__init__()
        # Stem. Note the order used here: pooling comes before batch norm.
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        filters = [64, 64, 128, 256]

        def build_stage(first_name, later_name_fmt, c_in, c_out, downsample, count):
            # First block may change width/resolution; the rest keep c_out.
            stage = nn.Sequential()
            stage.add_module(first_name, resblock(c_in, c_out, downsample=downsample))
            for block_no in range(2, count + 1):
                stage.add_module(later_name_fmt % (block_no,),
                                 resblock(c_out, c_out, downsample=False))
            return stage

        self.layer1 = build_stage('conv2_1', 'conv2_%d', filters[0], filters[1], False, repeat[0])
        self.layer2 = build_stage('conv3_1', 'conv3_%d', filters[1], filters[2], True, repeat[1])
        # The follow-up blocks of this stage are registered as 'conv2_<n>'
        # (not 'conv4_<n>'); the names are part of the state_dict keys, so
        # they are kept as they are.
        self.layer3 = build_stage('conv4_1', 'conv2_%d', filters[2], filters[3], True, repeat[2])

        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(filters[3], outputs)

    def forward(self, input):
        """Run stem and stages, pool to 1x1, flatten per sample and apply ``fc``."""
        features = self.layer0(input)
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.gap(features)
        # start_dim=1 keeps the batch dimension.
        return self.fc(torch.flatten(pooled, start_dim=1))

In [32]:
class ResEncoderModel(Module):
    """Patch encoder: a convolutional stem, three groups of bottleneck blocks, global pooling.

    The stem is a 7x7 stride-1 convolution to 256 channels followed by batch
    norm, ReLU and a stride-2 max pool. Three groups of 10
    ``ResNetBottleneckBlock`` follow; the first block of the second and of
    the third group is created as a downsampling block, after which the
    tracked channel count is doubled.
    ``ResNetBottleneckBlock`` is defined outside this class; the original
    notes give the feature maps as 256x32x32 -> 512x16x16 -> 1024x8x8 for a
    3x64x64 input - TODO confirm against that block.
    """

    def __init__(self):
        super(ResEncoderModel, self).__init__()

        self.conv_blocks = [10,10,10]
        self.num_blocks = len(self.conv_blocks)
        self.start_channels = 256

        self.prep = nn.Sequential(
            nn.Conv2d(
                in_channels = 3,
                out_channels = self.start_channels,
                kernel_size = 7,
                stride = 1, # keep the resolution here; only the pool below halves it
                padding = 3
            ),
            nn.BatchNorm2d(
                num_features = self.start_channels,
            ),
            nn.ReLU(),
            nn.MaxPool2d(
                kernel_size = 3,
                stride = 2,
                padding = 1
            )
        )

        current_channels = self.start_channels

        self.resnet_blocks = nn.ModuleList()

        for block_idx, conv_block_num in enumerate(self.conv_blocks):
            resnet_block = nn.Sequential()

            for conv_block_idx in range(conv_block_num):

                # Only the first block of every group but the first downsamples.
                is_downsampling_block = block_idx > 0 and conv_block_idx == 0

                resnet_block.add_module(
                    f'conv_{conv_block_idx}',
                    ResNetBottleneckBlock(
                        in_channels_block = current_channels,
                        is_downsampling_block = is_downsampling_block
                    )
                )

                if is_downsampling_block:
                    current_channels *= 2

            self.resnet_blocks.append(resnet_block)

        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)


    def forward(self, x):
        """Encode a batch of patches into one feature vector per patch.

        Submodules are invoked by calling them (not through ``.forward``) so
        that registered hooks run.
        """
        x = self.prep(x)
        for resnet_block in self.resnet_blocks:
            x = resnet_block(x)

        x = self.avg_pool(x)
        # The pooled map is 1x1; drop the two spatial dims, keep the batch dim.
        return torch.flatten(x, start_dim=1)

# **Context Prediction Model**

* **CPC with PixelCNN 3x3**

In [33]:
class ContextPredictionModel(Module):
    """Context network over 3x3 windows of an encoded patch grid, with linear predictors.

    ``forward`` slides a 3x3 window over a 5x5 set of positions of the input
    grid, turns every window into one context vector (three batch-norm /
    ReLU / 1x1-convolution layers followed by global average pooling) and
    predicts encodings 2, 3 and 4 grid cells further down and further right
    of the window centre with separate linear layers.
    The slicing expects a 7x7 spatial grid, e.g. (batch, in_channels, 7, 7).
    """

    def __init__(self, in_channels):
        super(ContextPredictionModel, self).__init__()

        self.in_channels = in_channels

        self.context_layers = 3
        self.context_conv = nn.Sequential()

        for layer_idx in range(self.context_layers):
            self.context_conv.add_module(f'batch_norm_{layer_idx}',nn.BatchNorm2d(self.in_channels))
            self.context_conv.add_module(f'relu_{layer_idx}',nn.ReLU())
            self.context_conv.add_module(
                f'conv2d_{layer_idx}',
                nn.Conv2d(
                    in_channels = self.in_channels,
                    out_channels = self.in_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            )

        self.context_conv.add_module(
            'adaptive_avg_pool',
            nn.AdaptiveAvgPool2d(output_size=1)
        )

        # Index 0: predictors along y (rows), index 1: predictors along x
        # (columns). Four containers are created but only the first two are
        # filled; the layout is kept unchanged.
        self.prediction_weights = nn.ModuleList([nn.ModuleList() for i in range(4)])

        for direction in range(2):
            for prediction_steps in range(3):
                self.prediction_weights[direction].append(
                    nn.Linear(
                        in_features = self.in_channels,
                        out_features = self.in_channels,
                    )
                )

    def forward(self, x):
        """Return ``(predictions, locations)`` for the encoded grid ``x``.

        ``predictions`` stacks, along dim 0, the outputs of the three
        y-direction predictors followed by the three x-direction predictors.
        ``locations`` is an integer tensor (created on the CPU) with one
        ``(row, col)`` target position per prediction row.
        Only windows whose centre row is 1 or 2 feed the y predictors, and
        only windows whose centre column is 1 or 2 feed the x predictors, so
        every target position stays inside rows/columns 0..6.
        """
        batch_size = x.shape[0]

        windows = []
        centres = []  # one (row, col) window centre per context vector
        for y1 in range(5):
            for x1 in range(5):
                windows.append(x[:, :, y1:y1 + 3, x1:x1 + 3])
                centres += [(y1 + 1, x1 + 1)] * batch_size

        # One context vector per (window, image): (25 * batch, in_channels).
        context_vectors = self.context_conv(torch.cat(windows, dim=0))
        context_vectors = context_vectors.flatten(start_dim=1)

        centres_t = torch.tensor(centres, dtype=torch.long).reshape(-1, 2)

        rows_for_y = [i for i, (cy, _) in enumerate(centres) if cy in (1, 2)]
        rows_for_x = [i for i, (_, cx) in enumerate(centres) if cx in (1, 2)]

        context_for_y = context_vectors[rows_for_y]
        loc_for_y = centres_t[rows_for_y]
        context_for_x = context_vectors[rows_for_x]
        loc_for_x = centres_t[rows_for_x]

        all_predictions = []
        all_loc = []

        # The k-th predictor targets the cell (k + 2) steps from the centre.
        for step, predictor in enumerate(self.prediction_weights[0]):
            all_predictions.append(predictor(context_for_y))
            all_loc.append(loc_for_y + torch.tensor([step + 2, 0]))

        for step, predictor in enumerate(self.prediction_weights[1]):
            all_predictions.append(predictor(context_for_x))
            all_loc.append(loc_for_x + torch.tensor([0, step + 2]))

        return torch.cat(all_predictions, dim = 0), torch.cat(all_loc, dim = 0)

* **Our version of CPC with PixelCNN**

In [34]:
class ContextPredictionModel2(Module):
    """Context network over the top rows of an encoded patch grid, with four linear predictors.

    For each of the first three rows ``y1`` the context is computed from all
    columns of rows ``0..y1`` (three batch-norm / ReLU / 1x1-convolution
    layers followed by global average pooling). Four linear layers then
    predict the encodings 1 to 4 rows below ``y1``.
    The slicing expects a 7-column grid, e.g. (batch, in_channels, 7, 7).
    """

    def __init__(self, in_channels):
        super(ContextPredictionModel2, self).__init__()

        self.in_channels = in_channels

        self.context_layers = 3
        self.context_conv = nn.Sequential()

        for layer_idx in range(self.context_layers):
            self.context_conv.add_module(f'batch_norm_{layer_idx}',nn.BatchNorm2d(self.in_channels))
            self.context_conv.add_module(f'relu_{layer_idx}',nn.ReLU())
            self.context_conv.add_module(
                f'conv2d_{layer_idx}',
                nn.Conv2d(
                    in_channels = self.in_channels,
                    out_channels = self.in_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            )

        self.context_conv.add_module(
            'adaptive_avg_pool',
            nn.AdaptiveAvgPool2d(output_size=1)
        )

        # One predictor per prediction distance (1..4 rows below).
        self.prediction_weights = nn.ModuleList([nn.Linear(
                        in_features = self.in_channels,
                        out_features = self.in_channels,
                    ) for i in range(4)])

    def forward(self, x):
        """Return ``(predictions, locations)`` for the encoded grid ``x``.

        ``predictions`` stacks the outputs of the four predictors along
        dim 0; each predictor sees ``3 * 7 * batch`` context vectors.
        ``locations`` is an integer tensor (created on the CPU) holding the
        ``(row, col)`` target of every prediction row.
        """
        batch_size = x.shape[0]

        context_chunks = []
        locations = []

        for y1 in range(3): # rows
            # NOTE(review): the context does not depend on the column - the
            # same band (rows 0..y1, all 7 columns) is used for each of the 7
            # column positions, so all columns of a row share one context
            # vector. Confirm whether a per-column window was intended.
            band = x[:, :, 0:y1 + 1, 0:7]
            replicated = torch.cat([band] * 7, dim=0)  # 7 copies, one per column

            context_chunks.append(self.context_conv(replicated))

            for x1 in range(7): # columns
                locations += [(y1, x1)] * batch_size

        context_vectors = torch.cat(context_chunks, dim=0).flatten(start_dim=1)
        context_loc = torch.tensor(locations, dtype=torch.long).reshape(-1, 2)

        all_predictions = []
        all_loc = []

        # The k-th predictor targets the cell (k + 1) rows below.
        for step, predictor in enumerate(self.prediction_weights):
            all_predictions.append(predictor(context_vectors))
            all_loc.append(context_loc + torch.tensor([step + 1, 0]))

        return torch.cat(all_predictions, dim = 0), torch.cat(all_loc, dim = 0)

* **CPC with only the encoder: predicts neighbouring patches along a given direction**

In [35]:
class ContextPredictionModelWithDir(Module):
    """Linear predictors applied directly to encoded patches, along one direction.

    Args:
        in_channels: size of one encoded patch vector.
        direction: ``'DOWN'``, ``'UP'``, ``'RIGHT'`` or ``'LEFT'``. As in the
            original code, any other value is treated like ``'LEFT'``.

    Source cells per direction (on a 7x7 grid):
        DOWN  - rows 0..2, all columns; targets 1..4 rows below.
        UP    - rows 4..6, all columns; targets 1..4 rows above.
        RIGHT - all rows, columns 0..2; targets 1..4 columns to the right.
        LEFT  - all rows, columns 4..6; targets 1..4 columns to the left.
    """

    def __init__(self, in_channels, direction):
        super(ContextPredictionModelWithDir, self).__init__()
        self.in_channels = in_channels
        self.direction = direction

        # One predictor per prediction distance (1..4 cells).
        self.prediction_weights = nn.ModuleList([nn.Linear(
                        in_features = self.in_channels,
                        out_features = self.in_channels,
                    ) for i in range(4)])

    def forward(self, x):
        """Return ``(predictions, locations)`` for the encoded grid ``x``.

        ``x`` is indexed as ``[image, channel, row, col]`` and is expected to
        be a 7x7 grid. Every source cell of every image contributes one
        vector, ordered by row, then column, then image. ``predictions``
        stacks the outputs of the four predictors along dim 0 and
        ``locations`` (integer tensor created on the CPU) holds the matching
        ``(row, col)`` target positions, so both have the same length.
        """
        if self.direction == 'DOWN':
            rows, cols, unit_step = range(0, 3), range(0, 7), (1, 0)
        elif self.direction == 'UP':
            rows, cols, unit_step = range(4, 7), range(0, 7), (-1, 0)
        elif self.direction == 'RIGHT':
            rows, cols, unit_step = range(0, 7), range(0, 3), (0, 1)
        else:
            rows, cols, unit_step = range(0, 7), range(4, 7), (0, -1)

        batch_size = x.shape[0]

        # All source cells at once: (batch, C, n_rows, n_cols) ->
        # (n_rows, n_cols, batch, C) -> (n_rows * n_cols * batch, C).
        region = x[:, :, rows.start:rows.stop, cols.start:cols.stop]
        z_patches = region.permute(2, 3, 0, 1).reshape(-1, x.shape[1])

        # Same (row, col, image) order as the flattened vectors above.
        locations = [(y1, x1) for y1 in rows for x1 in cols for _ in range(batch_size)]
        context_loc = torch.tensor(locations, dtype=torch.long).reshape(-1, 2)

        all_predictions = []
        all_loc = []

        for step, predictor in enumerate(self.prediction_weights):
            all_predictions.append(predictor(z_patches))
            # The k-th predictor targets the cell (k + 1) steps away.
            all_loc.append(context_loc + torch.tensor(unit_step) * (step + 1))

        return torch.cat(all_predictions, dim = 0), torch.cat(all_loc, dim = 0)

# **Context predictor training**

In [36]:
def run_context_predictor(args, res_encoder_model, context_predictor_model, models_store_path):
    """Train the encoder with a contrastive (InfoNCE-style) context-prediction loss.

    For every training batch the image patches are encoded, the context predictor
    produces predictions for patches further along ``context_predictor_model.direction``,
    and each prediction is scored against its true target (positive) and against
    ``args.num_random_patches`` encoded random patches (negatives). The loss is the
    mean negative log-softmax of the positive.

    Args:
        args: parsed CLI namespace; reads ``train_image_folder``, ``test_image_folder``,
            ``num_classes_train``, ``num_classes_test``, ``sub_batch_size``,
            ``batch_size``, ``num_random_patches``, ``num_epochs`` and ``device``.
        res_encoder_model: encoder being trained; its weights are checkpointed.
        context_predictor_model: model returning ``(predictions, locations)`` for an
            encoded patch grid; must expose a ``direction`` attribute.
        models_store_path: existing directory receiving the checkpoints and
            ``pred_stats.csv``.

    Side effects:
        Writes ``last_res_encoder_weights.pt`` after every optimizer step and
        ``<direction>_best_res_encoder_weights.pt`` whenever the mean batch loss
        improves on the best value seen so far in this run; appends the loss
        statistics to ``pred_stats.csv`` through ``write_csv_stats``.

    Raises:
        RuntimeError: if no random (negative) patches could be obtained at all.
    """

    print("RUNNING CONTEXT PREDICTOR " +str(context_predictor_model.direction)+ " TRAINING")

    # Used to build the name of the file where the best encoder weights are saved
    prefix = str(context_predictor_model.direction)
    best_encoder = lambda prefix: f"{prefix}_best_res_encoder_weights.pt"

    # Load the datasets; only the training split is used in this function
    dataset_train, _ = get_imagenet_datasets(args.train_image_folder, args.test_image_folder, num_classes_train = args.num_classes_train, num_classes_test = args.num_classes_test)

    stats_csv_path = os.path.join(models_store_path, "pred_stats.csv")

    # Create the data loaders
    random_patch_loader = get_random_patch_loader(dataset_train)
    data_loader_train = DataLoader(dataset_train, args.sub_batch_size, shuffle = True)

    # NOTE(review): only the encoder parameters are handed to the optimizer, so the
    # context predictor weights are never updated - confirm this is intended.
    params = list(res_encoder_model.parameters())
    optimizer = torch.optim.Adam(params = params, lr=0.00001)
    # NOTE(review): the early stopper is created but never consulted in this function.
    early_stopper = Patience(100, True)

    # Best mean batch loss of the whole run. It must live outside the loops:
    # re-initialising it per batch would overwrite the "best" checkpoint at every step.
    best_batch_loss = float('inf')

    # Last successfully fetched negatives; reused if the loader cannot deliver new ones
    random_patches = None

    for epoch in range(1, args.num_epochs + 1):

        print("RUNNING EPOCH #" + str(epoch))
        res_encoder_model.train()

        for batch in data_loader_train:

            img_batch = batch['image'].to(args.device)
            patch_batch = get_patch_tensor_from_image_batch(img_batch)
            batch_size = len(img_batch)
            num_patch_groups = len(patch_batch) # 25 according to the original notes

            # NOTE(review): this counter restarts for every batch and grows by
            # batch_size for each patch group, so with the default arguments the
            # optimizer steps after every batch; with args.batch_size larger than
            # num_patch_groups * batch_size it would never step - confirm the
            # intended gradient-accumulation scheme.
            sub_batches_processed = 0
            batch_loss = 0
            sum_batch_loss = 0

            for group_idx in range(num_patch_groups):

                # Apply the encoder to all the patches of this group
                patches_encoded = res_encoder_model(patch_batch[group_idx]) # 98, 1024 for sub_batch_size=2
                patches_encoded = patches_encoded.view(batch_size, 7, 7, -1) # reshape 2, 7, 7, 1024
                patches_encoded = patches_encoded.permute(0, 3, 1, 2) # 2, 1024, 7, 7

                # Fetch the negatives; if the loader reports it is exhausted, rebuild
                # it and try once more. Stop as soon as a batch of patches is obtained.
                for _ in range(2):
                    patches_return = get_random_patches(random_patch_loader, args.num_random_patches)
                    if patches_return['is_data_loader_finished']:
                        random_patch_loader = get_random_patch_loader(dataset_train)
                    else:
                        random_patches = patches_return['patches_tensor'].to(args.device)
                        break

                if random_patches is None:
                    raise RuntimeError("Could not fetch random patches: the random patch loader is exhausted even after being recreated")

                # Apply the encoder to the random patches
                enc_random_patches = res_encoder_model(random_patches)

                # Apply the context predictor to the encoded patches
                predictions, locations = context_predictor_model(patches_encoded) # 112, 1024
                losses = []

                # Predictions come in consecutive groups of batch_size rows that share
                # the same target location (one row per image of the batch).
                for b in range(len(predictions)//batch_size):

                    b_idx_start = b*batch_size
                    b_idx_end = (b+1)*batch_size

                    p_y = locations[b_idx_start][0]
                    p_x = locations[b_idx_start][1]

                    target = patches_encoded[:,:,p_y,p_x]
                    pred = predictions[b_idx_start:b_idx_end] # 2, 1024

                    # Row 0 holds the positive score, the following rows the negatives
                    good_term_dot = dot(pred, target)
                    dot_terms = [torch.unsqueeze(good_term_dot,dim=0)]

                    for random_patch_idx in range(args.num_random_patches):

                        bad_term_dot = dot(pred, enc_random_patches[random_patch_idx:random_patch_idx+1])
                        dot_terms.append(torch.unsqueeze(bad_term_dot, dim=0))

                    log_softmax = torch.log_softmax(torch.cat(dot_terms, dim=0), dim=0)
                    losses.append(-log_softmax[0,])

                group_losses = torch.cat(losses)
                loss = torch.mean(group_losses)
                loss.backward()

                sub_batches_processed += img_batch.shape[0]
                # Add each group's loss exactly once (adding a running cumulative sum
                # here would count early groups many times over).
                batch_loss += loss.detach().to('cpu')
                sum_batch_loss += torch.sum(group_losses.detach().to('cpu'))

            # Mean loss over the patch groups of this batch
            if num_patch_groups:
                batch_loss = batch_loss/num_patch_groups

            if sub_batches_processed >= args.batch_size:

                optimizer.step()
                optimizer.zero_grad()

                print(f"{datetime.datetime.now()} Loss: {batch_loss}")
                print(f"{datetime.datetime.now()} SUM Loss: {sum_batch_loss}")

                torch.save(res_encoder_model.state_dict(), os.path.join(models_store_path, "last_res_encoder_weights.pt"))

                if best_batch_loss > batch_loss:
                    best_batch_loss = batch_loss
                    torch.save(res_encoder_model.state_dict(), os.path.join(models_store_path, best_encoder(prefix)))

                stats = dict(
                    batch_loss = batch_loss,
                    sum_batch_loss = sum_batch_loss
                )
                write_csv_stats(stats_csv_path, stats)

# **Anomaly detection evaluation**

In [37]:
def calculate_score_dir(enc_patches, predictions, locations, enc_random_patches, num_random_patches=None):
    """Compute the contrastive (InfoNCE-style) score of one direction for one image.

    Every row of ``predictions`` (x_tk) is compared, through a dot product, with the
    encoded patch found at its associated grid location (x_t, the positive) and with
    the first ``num_random_patches`` rows of ``enc_random_patches`` (the negatives).
    The score is the mean over all rows of the negative log-softmax of the positive.

    Args:
        enc_patches: encoded patch grid indexed as ``enc_patches[:, :, row, col]``;
            the callers in this notebook pass a (1, Z, 7, 7) tensor.
        predictions: 2-D tensor with one latent vector per row.
        locations: sequence of ``(row, col)`` grid positions of the x_t patches.
        enc_random_patches: 2-D tensor of encoded random patches, one per row.
        num_random_patches: number of negatives to use. Defaults to the
            notebook-level ``args.num_random_patches`` when omitted.

    Returns:
        0-dim tensor with the mean loss.
    """
    if num_random_patches is None:
        # Backward-compatible default: fall back to the notebook-level CLI arguments
        num_random_patches = args.num_random_patches

    losses = []

    for b in range(len(predictions)):
        # TODO: ugly, needs cleanup. Rows from index 21 on reuse the location found
        # 7 entries earlier (the callers in this notebook pass 28 predictions but
        # only 21 locations).
        loc_idx = b if b < 21 else b - 7
        p_y = locations[loc_idx][0]
        p_x = locations[loc_idx][1]

        x_t = enc_patches[:,:,p_y,p_x]
        # view(1, -1) keeps the row as a (1, Z) matrix without hard-coding Z
        x_tk = predictions[b].view(1, -1)

        # Row 0 holds the positive score, the following rows the negatives
        good_term_dot = dot(x_tk, x_t)
        dot_terms = [torch.unsqueeze(good_term_dot,dim=0)]

        for random_patch_idx in range(num_random_patches):
            bad_term_dot = dot(x_tk, enc_random_patches[random_patch_idx:random_patch_idx+1])
            dot_terms.append(torch.unsqueeze(bad_term_dot, dim=0))

        log_softmax = torch.log_softmax(torch.cat(dot_terms, dim=0), dim=0)
        losses.append(-log_softmax[0,])

    loss = torch.mean(torch.cat(losses))
    return loss

def _encode_patch_grid(encoder, patches):
    """Encode the patches of one image and arrange them as a (1, Z, 7, 7) grid."""
    encoded = encoder(patches) # 49, 1024 according to the original notes
    encoded = encoded.view(1, 7, 7, -1) # reshape 1, 7, 7, 1024
    return encoded.permute(0, 3, 1, 2) # 1, 1024, 7, 7


def run_anomaly_evaluation(args, res_encoder_down_model, res_encoder_up_model, res_encoder_right_model, res_encoder_left_model, models_store_path):
    """Score every test image with the four directional encoders and print the metrics.

    For each test image and each direction the patches are encoded, a slice of the
    encoded grid is used as predictions (x_tk) and scored by ``calculate_score_dir``
    against the listed context locations (x_t) and a batch of encoded random patches.
    The image score is the mean of the four directional scores. Images whose class
    name is not ``'good'`` are labelled as anomalies; predicted labels are obtained by
    thresholding the scores at the percentile matching the share of normal images.

    Args:
        args: parsed CLI namespace; reads the dataset folders, ``num_classes_train``,
            ``num_classes_test``, ``num_random_patches`` and ``device``.
        res_encoder_down_model, res_encoder_up_model, res_encoder_right_model,
        res_encoder_left_model: encoders with their weights already loaded.
        models_store_path: currently unused (kept for interface compatibility).

    Side effects:
        Prints, per image, its index, class name, score and predicted label, then the
        AUC, the output of ``compute_pre_recall_f1`` and the average precision.

    Raises:
        RuntimeError: if no random (negative) patches could be obtained at all.
    """

    print("RUNNING ANOMALY DETECTION")

    dataset_train, dataset_test = get_imagenet_datasets(args.train_image_folder, args.test_image_folder, train_split=1, test_split=0, num_classes_train = args.num_classes_train, num_classes_test = args.num_classes_test)
    data_loader_test = DataLoader(dataset_test, 1, shuffle = False)
    NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()
    print(NUM_TEST_SAMPLES)

    random_patch_loader = get_random_patch_loader(dataset_train)

    every = slice(None)
    # One entry per direction, in the order the scores are averaged:
    # (encoder, (row slice, column slice) of the grid used as x_tk, x_t locations).
    # The location tables do not depend on the image, so they are built only once.
    # NOTE(review): for RIGHT and LEFT the predictions are flattened over 7 rows x 4
    # columns while the locations are listed over 7 rows x 3 columns, so
    # calculate_score_dir appears to pair patches that lie on different rows -
    # confirm this pairing is intended.
    direction_setups = [
        (res_encoder_down_model, (slice(3, 7), every), [(row, col) for row in range(3) for col in range(7)]),
        (res_encoder_up_model, (slice(0, 4), every), [(row, col) for row in range(4, 7) for col in range(7)]),
        (res_encoder_right_model, (every, slice(3, 7)), [(row, col) for row in range(7) for col in range(3)]),
        (res_encoder_left_model, (every, slice(0, 4)), [(row, col) for row in range(7) for col in range(4, 7)]),
    ]

    for encoder, _, _ in direction_setups:
        encoder.eval()

    score_all = []
    label_all = []
    # Only the class names are needed for the final report; keeping the whole batch
    # dictionaries would hold every image tensor in memory.
    class_names = []

    # Last successfully fetched negatives; reused if the loader cannot deliver new ones
    random_patches = None

    for data in data_loader_test:
        class_name = data['class_name'][0]

        img = data['image'].to(args.device)
        patches = get_patch_tensor_from_image_batch(img)

        image_scores = []
        with torch.no_grad():
            patches_return = get_random_patches(random_patch_loader, args.num_random_patches)
            if patches_return['is_data_loader_finished']:
                # Rebuild the loader AND fetch again, otherwise this image would be
                # scored with stale (or not yet defined) random patches.
                random_patch_loader = get_random_patch_loader(dataset_train)
                patches_return = get_random_patches(random_patch_loader, args.num_random_patches)

            if not patches_return['is_data_loader_finished']:
                random_patches = patches_return['patches_tensor'].to(args.device)
            elif random_patches is None:
                raise RuntimeError("Could not fetch random patches: the random patch loader is exhausted even after being recreated")

            for encoder, (rows, cols), locations in direction_setups:
                enc_patches = _encode_patch_grid(encoder, patches)
                enc_random_patches = encoder(random_patches)

                # x_tk: the selected 4 grid rows/columns flattened to one latent per row
                predictions_dir = enc_patches[:, :, rows, cols]
                predictions_dir = predictions_dir.permute(0, 2, 3, 1)
                predictions_dir = predictions_dir.reshape(28, -1) # total number of predictions (4*7)

                image_scores.append(calculate_score_dir(enc_patches, predictions_dir, locations, enc_random_patches))

        class_names.append(class_name)
        label_all.append(0 if class_name == 'good' else 1)

        # Mean of the losses of the 4 models
        score_all.append(sum(image_scores) / len(image_scores))

    # Flatten the per-image scores into a 1-D array
    score_all = np.concatenate([np.atleast_1d(s.cpu().numpy()) for s in score_all])

    # Compute threshold -> predictions
    normal_ratio = sum(1 for a in label_all if a == 0) / len(label_all)
    threshold = np.percentile(score_all, 100 * normal_ratio)
    predictions = np.zeros(len(score_all))
    predictions[score_all > threshold] = 1

    for i, class_name in enumerate(class_names):
        print("IMAGE #" + str(i))
        print(class_name)
        print(score_all[i])
        print(predictions[i])

    auc = roc_auc_score(label_all, score_all)
    f1 = compute_pre_recall_f1(label_all, predictions)
    ap = average_precision_score(label_all, score_all)
    print('AUC: ' + str(auc) + '\nF1: ' + str(f1) + '\nAverage Precision: ' + str(ap))
                     


# **MAIN**

* **Anomaly Detection MAIN**
> Main for detecting anomalies in pictures using Contrastive Predictive Coding

In [38]:
IMG_SIZE = (768,768)
parser = argparse.ArgumentParser(description='Contrastive predictive coding params')

# Supported modes:
#   'train_encoder_context_prediction' - train one encoder per direction
#   'anomaly_evaluation'               - score the test set with the four encoders
parser.add_argument('-mode', default='train_encoder_context_prediction' , type=str)
# chest-xray-pneumonia/chest_xray/train
# mvtec-ad/bottle/train          num_class 1
parser.add_argument('-train_image_folder', default='../input/mvtec-ad/bottle/train', type=str)
parser.add_argument('-num_classes_train', default=1, type=int)
# chest-xray-pneumonia/chest_xray/train  num_class 2
# mvtec-ad/bottle/test                   num_class 4
parser.add_argument('-test_image_folder', default='../input/mvtec-ad/bottle/test', type=str)
parser.add_argument('-num_classes_test', default=4, type=int)
parser.add_argument('-batch_size', default=16, type=int)
parser.add_argument('-sub_batch_size', default=2, type=int)
parser.add_argument('-num_random_patches', default=15, type=int)
parser.add_argument('-num_epochs', default=1, type=int)
# cpu or cuda
parser.add_argument('-device', default='cuda', type=str)

# parse_known_args tolerates the extra arguments injected by the notebook kernel
args, args_other = parser.parse_known_args()

print(f"Running CPC with args {args}")

Z_DIMENSIONS = 1024
DIRECTIONS = ['DOWN', 'UP', 'RIGHT', 'LEFT']

# Output folders; creating them is a no-op when they already exist
stored_models_root_path = "trained_models"
os.makedirs(stored_models_root_path, exist_ok=True)
stored_eval_root_path = "evaluation"
os.makedirs(stored_eval_root_path, exist_ok=True)


def _load_res_encoder(direction_name, weights_path, device):
    """Build a ResNet encoder on ``device`` and load the weights stored at ``weights_path``."""
    # ResNet18 v2 up to the third residual block
    model = ResNet(3, ResBlock, [2, 2, 2, 2]).to(device)
    print(f"Loading res encoder {direction_name} weights from {weights_path}")
    # map_location lets checkpoints saved on a GPU be loaded when running on the CPU.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return model


if args.mode == 'train_encoder_context_prediction':

    for direc in DIRECTIONS:
        # Drop the references to the models of the previous direction before
        # building the new ones
        res_encoder_model = None
        context_predictor_model = None
        # ResNet18 v2 up to the third residual block
        res_encoder_model = ResNet(3, ResBlock, [2, 2, 2, 2]).to(args.device)
        context_predictor_model = ContextPredictionModelWithDir(in_channels=Z_DIMENSIONS, direction=direc).to(args.device)

        # Models training
        model_store_folder = get_next_model_folder(direc, stored_models_root_path)
        os.mkdir(model_store_folder)
        run_context_predictor(args, res_encoder_model, context_predictor_model, model_store_folder)


elif args.mode == 'anomaly_evaluation':
    # Evaluation with pre-trained weights. To evaluate locally trained weights
    # instead, point each path at
    # os.path.join(stored_models_root_path, '<DIRECTION>_model_run_<N>', '<DIRECTION>_best_res_encoder_weights.pt')
    # (put the run number after the last underscore of the folder name).
    res_encoder_down_weights_path = '/kaggle/input/models80/down80_best_res_encoder_weights.pt'
    res_encoder_up_weights_path = '/kaggle/input/models80/up80_best_res_encoder_weights.pt'
    res_encoder_right_weights_path = '/kaggle/input/models80/right80_best_res_encoder_weights.pt'
    res_encoder_left_weights_path = '/kaggle/input/models80/left80_best_res_encoder_weights.pt'

    res_encoder_down_model = _load_res_encoder('down', res_encoder_down_weights_path, args.device)
    res_encoder_up_model = _load_res_encoder('up', res_encoder_up_weights_path, args.device)
    res_encoder_right_model = _load_res_encoder('right', res_encoder_right_weights_path, args.device)
    res_encoder_left_model = _load_res_encoder('left', res_encoder_left_weights_path, args.device)

    run_anomaly_evaluation(args, res_encoder_down_model, res_encoder_up_model, res_encoder_right_model, res_encoder_left_model, stored_eval_root_path)

Running CPC with args Namespace(batch_size=16, device='cuda', mode='train_encoder_context_prediction', num_classes_test=4, num_classes_train=1, num_epochs=1, num_random_patches=15, sub_batch_size=2, test_image_folder='../input/mvtec-ad/bottle/test', train_image_folder='../input/mvtec-ad/bottle/train')
STARTING DOWN RUN 3! Storing the models at trained_models/DOWN_model_run_3
RUNNING CONTEXT PREDICTOR DOWN TRAINING
RUNNING EPOCH #1


KeyboardInterrupt: 