In [3]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from PIL import Image
from torchbearer.cv_utils import DatasetValidationSplitter
from livelossplot import PlotLosses

from torch import nn
from tqdm import tqdm

In [None]:
import cv2
import numpy as np 

def padding(img, shape_r=480, shape_c=640, channels=3):
    """Resize ``img`` to fit a ``shape_r`` x ``shape_c`` canvas and zero-pad the rest.

    The image is scaled with ``cv2.resize`` so that the side with the larger
    size ratio fills the canvas, keeping the aspect ratio up to integer
    rounding. The resized image is then centred on a zero-filled canvas.

    Args:
        img: Array whose first two axes are rows and columns.
        shape_r: Number of rows of the returned canvas.
        shape_c: Number of columns of the returned canvas.
        channels: ``1`` produces a 2-D canvas; any other value produces a
            3-D canvas with that many channels on the last axis.

    Returns:
        ``np.uint8`` array of shape ``(shape_r, shape_c)`` when
        ``channels == 1``, otherwise ``(shape_r, shape_c, channels)``.
    """
    canvas_shape = (shape_r, shape_c) if channels == 1 else (shape_r, shape_c, channels)
    img_padded = np.zeros(canvas_shape, dtype=np.uint8)

    rows, cols = img.shape[:2]

    if rows / shape_r > cols / shape_c:
        # Height is the limiting side: fill all rows, centre horizontally.
        # Clamp BEFORE resizing: an extremely tall image would otherwise floor
        # to a width of 0, which cv2.resize cannot produce.
        new_cols = min(max((cols * shape_r) // rows, 1), shape_c)
        resized = cv2.resize(img, (new_cols, shape_r))  # dsize is (width, height)
        left = (shape_c - new_cols) // 2
        img_padded[:, left:left + new_cols] = resized
    else:
        # Width is the limiting side: fill all columns, centre vertically.
        new_rows = min(max((rows * shape_c) // cols, 1), shape_r)
        resized = cv2.resize(img, (shape_c, new_rows))  # dsize is (width, height)
        top = (shape_r - new_rows) // 2
        img_padded[top:top + new_rows, :] = resized

    return img_padded

In [None]:
class Cat200Loader:
    """Builds the train / val / test datasets and their ``DataLoader`` objects.

    Paths are derived from ``root_path`` as ``{root_path}/{split}/Stimuli/``
    and ``{root_path}/{split}/FIXATIONMAPS/``. The ``train`` split is divided
    into training and validation subsets with ``DatasetValidationSplitter``;
    the ``test`` split is loaded without fixation maps.

    Args:
        root_path: Directory that contains the ``train`` and ``test`` folders.
        batch_size: Batch size used by all three loaders.
        frac_train_to_be_val: Fraction handed to ``DatasetValidationSplitter``
            to carve the validation subset out of the training data.

    Attributes:
        datasets: dict with keys ``'train'``, ``'val'`` and ``'test'``.
        loaders: dict with the same keys, holding the matching ``DataLoader``.
    """

    def __init__(self, root_path, batch_size=8, frac_train_to_be_val=0.2):
        def imgs_path(split):
            return f'{root_path}/{split}/Stimuli/'

        def maps_path(split):
            return f'{root_path}/{split}/FIXATIONMAPS/'

        # NOTE(review): transform1 / transform2 are module-level names that are
        # not defined in the cells visible here; they must exist before this
        # class is instantiated.
        self.datasets = {}
        self.datasets['test'] = CustomImageDataset(imgs_path('test'), transform=transform1)
        full_train = CustomImageDataset(
            imgs_path('train'),
            maps_path('train'),
            transform=transform1,
            target_transform=transform2,
        )
        splitter = DatasetValidationSplitter(len(full_train), frac_train_to_be_val)
        self.datasets['val'] = splitter.get_val_dataset(full_train)
        self.datasets['train'] = splitter.get_train_dataset(full_train)

        # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
        # machines instead of requesting it unconditionally.
        pin = torch.cuda.is_available()

        # Only the training data is shuffled. Validation and test keep a fixed
        # order so their batches are the same from one epoch to the next.
        self.loaders = {
            'train': DataLoader(self.datasets['train'], batch_size=batch_size, shuffle=True, pin_memory=pin),
            'val': DataLoader(self.datasets['val'], batch_size=batch_size, shuffle=False, pin_memory=pin),
            'test': DataLoader(self.datasets['test'], batch_size=batch_size, shuffle=False, pin_memory=pin),
        }
        
        
class CustomImageDataset(Dataset):
    """Stimulus images, optionally paired with fixation maps.

    Files are discovered as ``<root>/<category>/<name>.jpg``. Images and maps
    are paired by position, so both listings are sorted to make the pairing
    independent of the order in which the filesystem returns entries.

    Args:
        imgs_path: Root directory of the stimulus images.
        fix_maps_path: Root directory of the fixation maps, or ``None`` when
            no maps are available.
        transform: Stored on the instance; not applied by this class.
        target_transform: Stored on the instance; not applied by this class.

    Raises:
        ValueError: If maps are given but their count differs from the images.
    """

    def __init__(self, imgs_path, fix_maps_path=None, transform=None, target_transform=None):
        self.images = self._list_jpgs(imgs_path)
        self.maps = self._list_jpgs(fix_maps_path) if fix_maps_path else None
        if self.maps is not None and len(self.maps) != len(self.images):
            raise ValueError(
                f'found {len(self.images)} images in {imgs_path} but '
                f'{len(self.maps)} fixation maps in {fix_maps_path}'
            )
        self.transform = transform
        self.target_transform = target_transform
        # `transforms` is the torchvision module imported at the top of the file.
        # NOTE(review): cv2.imread returns channels in BGR order while these
        # statistics are listed in RGB order - confirm whether a swap is wanted.
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    @staticmethod
    def _list_jpgs(root):
        """Return the sorted ``.jpg`` paths found one directory level below ``root``."""
        paths = []
        for category in sorted(os.listdir(root)):
            category_dir = os.path.join(root, category)
            if not os.path.isdir(category_dir):
                # Stray files next to the category folders are ignored.
                continue
            paths.extend(
                os.path.join(category_dir, name)
                for name in sorted(os.listdir(category_dir))
                if name.endswith('.jpg')
            )
        return paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(tensor, image_path, map_path)``. With fixation maps the tensor
            stacks the three normalised image channels and the tiled fixation
            map (scaled to [0, 1]) as the last channel. Without maps the
            tensor holds only the image channels and ``map_path`` is ``''``,
            so the item never contains ``None``.

        Raises:
            OSError: If an image or map file cannot be read.
        """
        image_path = self.images[idx]
        image = cv2.imread(image_path)
        if image is None:
            raise OSError(f'could not read image: {image_path}')
        image = padding(image, image_size1, image_size2, 3)
        # HWC -> CHW, then scale to [0, 1] BEFORE normalising: the mean/std
        # above are expressed on that scale.
        image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float) / 255.0
        image = self.norm(image)

        if not self.maps:
            return image, image_path, ''

        map_path = self.maps[idx]
        fix_map = cv2.imread(map_path, cv2.IMREAD_GRAYSCALE)
        if fix_map is None:
            raise OSError(f'could not read fixation map: {map_path}')
        fix_map = padding(fix_map, shape_r_gt, shape_c_gt, 1)
        fix_map = torch.tensor(fix_map, dtype=torch.float) / 255.0
        # Tile the 2-D map 8x8 into a (1, 8*rows, 8*cols) tensor so it can be
        # concatenated with the image; the original map is the top-left tile.
        # assumes image_size1 == 8 * shape_r_gt and image_size2 == 8 * shape_c_gt - TODO confirm
        fix_map = fix_map.repeat(1, 8, 8)

        return torch.cat([image, fix_map], 0), image_path, map_path

In [4]:
class Trainer:
    """Runs training and validation epochs for a model and plots the losses.

    Args:
        model: Module to train; moved to the selected device. ``run_epoch``
            also reads ``model.prior`` and passes a clone to the criterion.
        criterion: Loss module called as ``criterion(y_pred, y_true, prior)``;
            moved to the selected device.
        optimizer: Optimizer stepped during the ``'train'`` phase.
        loaders: Object exposing a ``loaders`` dict with ``'train'`` and
            ``'val'`` entries, each iterable in batches of
            ``(x, image_paths, map_paths)`` and carrying a ``dataset``.
    """

    def __init__(self, model, criterion, optimizer, loaders):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.criterion = criterion.to(self.device)
        self.optimizer = optimizer
        self.loaders = loaders
        # Created here as well so run_epoch can be called on its own.
        self.logs = {}

    def run_trainer(self, epochs):
        """Train for ``epochs`` epochs, validating and updating the live plot after each."""
        liveloss = PlotLosses()
        for epoch in range(epochs):
            self.logs = {}

            self.model.train()
            self.run_epoch('train', epoch)

            self.model.eval()
            with torch.no_grad():
                self.run_epoch('val', epoch)

            liveloss.update(self.logs)
            liveloss.send()

    def run_epoch(self, phase, epoch):
        """Run one pass over ``phase`` and store the result in ``self.logs``.

        Args:
            phase: ``'train'`` (parameters are updated) or ``'val'`` (the first
                sample of every batch is plotted next to its prediction).
            epoch: Epoch index, shown in the progress-bar label.

        The stored value ``self.logs[f'{phase}_loss']`` is the batch losses
        weighted by batch size and divided by the dataset length (NaN when the
        dataset is empty).
        """
        # NOTE(review): shape_r_gt, shape_c_gt and plt (matplotlib.pyplot) are
        # module-level names not defined in the cells visible here.
        loader = self.loaders.loaders[phase]
        running_loss = 0.0
        progress = tqdm(loader, desc=f'{phase} epoch {epoch}')
        for x, img, fmap in progress:
            # All channels but the last are the image; the LAST channel is the
            # fixation map, cropped back to its original size.
            x_true = x[:, :-1, :, :].to(self.device)
            y_true = x[:, -1, :shape_r_gt, :shape_c_gt].unsqueeze(1).to(self.device)

            y_pred = self.model(x_true)
            loss = self.criterion(y_pred, y_true, self.model.prior.clone())
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss * x_true.size(0)
            # Show the batch loss on the progress bar instead of printing a line per batch.
            progress.set_postfix(loss=batch_loss)

            if phase == 'val':
                plt.imshow(x_true[0][0].detach().cpu().numpy(), cmap='gray')
                plt.show()
                plt.imshow(y_pred[0][0].detach().cpu().numpy(), cmap='gray')
                plt.show()

        n_samples = len(loader.dataset)
        self.logs[f'{phase}_loss'] = running_loss / n_samples if n_samples else float('nan')
                

In [None]:
# Network under training.
# NOTE(review): the arguments are presumably (input channels, output channels) -
# confirm against the UNet definition, which is not part of these cells.
model = UNet(3, 1)

# Optional layer freezing (disabled). Uncomment to stop gradient updates for
# the first `last_freeze_layer` parameter tensors.
# last_freeze_layer = 23
# for i, param in enumerate(model.parameters()):
#     if i < last_freeze_layer:
#         param.requires_grad = False

# Loss built from the ground-truth map size.
criterion = ModMSELoss(shape_r_gt, shape_c_gt)

# SGD with Nesterov momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)
# Alternative (disabled): Adam restricted to the parameters that still require gradients.
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-4)

# Datasets and loaders rooted at the local `cat2000` directory.
loaders = Cat200Loader('cat2000')

In [None]:
# Wire the model, loss, optimizer and loaders together, then train for three epochs.
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
)
trainer.run_trainer(epochs=3)