In [None]:
"""
Load the data as 2x2 grids, each composed of four images.
"""
from torchvision import transforms as T
from random import sample, shuffle
import os
import torch

# ImageNet per-channel statistics used to normalize inputs for the pretrained
# backbone. The 'inverse' pipeline is derived from these same values so the
# forward and inverse transforms cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: fixed-size resize, random horizontal flip, then normalize.
    'train': T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    # Evaluation: deterministic resize, then normalize.
    'val': T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    # Undo the normalization on an already-normalized tensor (e.g. for display).
    # Normalize computes (x - mean) / std, so:
    #   step 1: mean=0, std=1/std   -> x * std
    #   step 2: mean=-mean, std=1   -> x + mean
    'inverse': T.Compose([
        T.Normalize(mean=[0, 0, 0], std=[1 / s for s in IMAGENET_STD]),
        T.Normalize(mean=[-m for m in IMAGENET_MEAN], std=[1, 1, 1]),
    ]),
}

# ImageNet validation split; it is passed to QuartileDataset below as the
# pool of filler tiles for the 2x2 grids.
# NOTE(review): this applies the augmenting 'train' transform rather than
# 'val' -- presumably intentional because the grids are used for training;
# confirm.
from torchvision.datasets import ImageNet

imagenet = ImageNet(
    split="val",
    transform=data_transforms['train'],
    root='~/datasets/ImageNet',
)

# Create oxford pets binary classification dataset (cats vs dogs).
# Each target becomes True for a cat breed and False for a dog breed.
from torchvision.datasets import OxfordIIITPet

# 1-based Oxford-IIIT Pet class ids treated as cats; every other id is a dog.
# These appear to be the dataset's 12 cat breeds (Abyssinian, Bengal, Birman,
# Bombay, British Shorthair, Egyptian Mau, Maine Coon, Persian, Ragdoll,
# Russian Blue, Siamese, Sphynx) -- verify against the dataset's class list.
CAT_CLASS_IDS = frozenset({1, 6, 7, 8, 10, 12, 21, 24, 27, 28, 33, 34})


def is_cat_label(label):
    """Return True if the breed ``label`` is one of ``CAT_CLASS_IDS``.

    The ``+ 1`` converts the 0-based label the dataset yields into the 1-based
    ids above. Being a named function (unlike a lambda), this can be pickled
    by reference, which DataLoader worker processes may need.
    """
    return label + 1 in CAT_CLASS_IDS


oxfordpets = OxfordIIITPet(
    root='~/datasets/',
    split='trainval',
    transform=data_transforms['train'],
    target_transform=is_cat_label,
)

class QuartileDataset(torch.utils.data.Dataset):
    """Compose four images into one 2x2 grid image.

    Each item places one image from ``dataset_1`` (the target, whose label is
    returned) in one quadrant and fills the other three quadrants with images
    from ``dataset_2`` (whose labels are discarded).

    Quadrants are ordered top-left, top-right, bottom-left, bottom-right, and
    the target lands in quadrant ``idx % 4``.

    Tiles are concatenated along dim 2 (side by side) and dim 1 (stacked), so
    they are assumed to be equally sized (C, H, W) tensors, as produced by a
    fixed-size ``Resize`` followed by ``ToTensor``; the grid is then
    (C, 2H, 2W).

    NOTE(review): if ``dataset_2`` can contain the positive class (ImageNet,
    for example, has cat classes), a filler tile may contradict the target
    label -- confirm the filler source is appropriate.
    """

    def __init__(self, dataset_1, dataset_2):
        """
        Args:
            dataset_1 (Dataset): The main dataset with target images
            dataset_2 (Dataset): The data used to fill the other 3 spots

        Raises:
            ValueError: If ``dataset_2`` is empty, since there would be
                nothing to fill the other quadrants with.
        """
        if len(dataset_2) == 0:
            raise ValueError("dataset_2 must contain at least one item")
        self.dataset_1 = dataset_1
        self.dataset_2 = dataset_2

    def __len__(self):
        """One grid per item of ``dataset_1``."""
        return len(self.dataset_1)

    def _filler_indices(self, idx):
        """Return the three ``dataset_2`` indices used as fillers for ``idx``.

        The base choices are ``idx``, ``2*idx`` and ``3*idx`` (mod
        ``len(dataset_2)``), which spreads fillers across ``dataset_2``. When
        a base choice repeats an earlier one (for example at ``idx == 0``,
        where all three are 0), it is advanced to the next unused index, so a
        grid never shows the same filler twice. The only exception is a
        ``dataset_2`` with fewer than three items, where repeats are
        unavoidable.
        """
        n = len(self.dataset_2)
        chosen = []
        for k in (1, 2, 3):
            j = (k * idx) % n
            # With n >= 3 at most two indices are already taken, so the probe
            # always finds a free one and terminates.
            while n >= 3 and j in chosen:
                j = (j + 1) % n
            chosen.append(j)
        return chosen

    def __getitem__(self, idx):
        """Return ``(grid, label)`` where ``label`` comes from ``dataset_1``."""
        target, label = self.dataset_1[idx]
        tiles = [target] + [
            self.dataset_2[j][0] for j in self._filler_indices(idx)
        ]
        # Move the target from the top-left slot into its quadrant
        # (a no-op when idx % 4 == 0).
        pos = idx % 4
        tiles[0], tiles[pos] = tiles[pos], tiles[0]
        top = torch.cat((tiles[0], tiles[1]), dim=2)
        bottom = torch.cat((tiles[2], tiles[3]), dim=2)
        return torch.cat([top, bottom], dim=1), label

# Loader over 2x2 grids: each item pairs one Oxford-IIIT Pet image (which
# supplies the label) with three ImageNet filler tiles, 20 grids per batch.
# NOTE(review): despite its name, this loader feeds the training loop.
# NOTE(review): with num_workers > 0 the dataset objects may need to be
# picklable (e.g. under the "spawn" start method) -- confirm on your platform.
testloader = torch.utils.data.DataLoader(
    dataset=QuartileDataset(oxfordpets, imagenet),
    batch_size=20,
    shuffle=True,
    num_workers=2,
)

"""
Load a pretrained PyTorch network (VGG19) and fine-tune it on the 2x2 grid data.
"""
# Pretrained VGG19 fetched via torch.hub, pinned to torchvision v0.10.0.
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg19', pretrained=True)
# Replace the last classifier layer with a single-output head (one logit per
# grid for the binary label), keeping the width of the layer it replaces.
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 1)

# BCELoss takes probabilities, so the training loop applies a sigmoid first.
criterion = torch.nn.BCELoss()
# Alternatives tried previously: resuming with torch.load("latest_fres.pt"),
# plain SGD with lr=0.0001, and Adam with lr=0.0001.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Per-batch history of accuracy and loss, plus the best batch loss so far.
ACC, LOSS = [], []
min_loss = float("inf")

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Train for 100 epochs. Log per-batch loss and accuracy with a live plot, and
# checkpoint the whole model whenever a batch reaches a new minimum loss.
device = next(model.parameters()).device  # train wherever the model lives
model.train()  # be explicit: dropout in the VGG classifier should be active

for epoch in range(100):  # loop over the dataset multiple times
    for i, (X, Y) in enumerate(testloader):
        X, Y = X.to(device), Y.to(device)
        optimizer.zero_grad()

        # criterion is BCELoss, which expects probabilities: squash the single
        # logit per grid with a sigmoid.
        Y_pred = torch.sigmoid(model(X).flatten())
        loss = criterion(Y_pred, Y.float())
        loss.backward()
        optimizer.step()

        # Track the loss as a Python float so min_loss and LOSS do not keep
        # tensors (and their autograd history) alive between iterations.
        loss_value = loss.item()
        if loss_value < min_loss:
            min_loss = loss_value
            # NOTE(review): checkpointing on a single batch's loss is noisy;
            # consider an epoch-level or held-out metric instead.
            torch.save(model, "finetuned.pt")

        correct_predictions = Y == (Y_pred > 0.5)
        ACC.append(correct_predictions.sum().item() / correct_predictions.numel())
        LOSS.append(loss_value)
        clear_output(wait=True)
        print("{} - {} - {}%".format(i, LOSS[-1], int(100 * ACC[-1])))
        plt.plot(ACC, label="Acc")
        plt.plot(LOSS, label="Loss")
        plt.legend(loc="upper left")
        plt.show()