In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import wandb
import os
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
import numpy as np
import json

In [2]:
# ---- Run mode ----
# Keep test_run False for actual training.
test_run = False
# Number of image pairs used in a test run.
test_run_size = 1024
training_experiment = False

# ---- Data split ----
# Share of the images held out for validation (expressed as a 0-1 fraction).
val_percent = 0.2
# Seed for train_test_split. Don't change it.
random_seed = 99

# ---- Optimisation ----
batch_size = 16
# Learning rate.
lr = 0.05
# Only matters if SGD is the optimiser.
momentum = 0.9
epochs = 20
# Mean-squared-error loss.
loss_fn = nn.MSELoss()

In [3]:
# JSON file holding the dataset statistics read below ("mean", "std", "n").
json_name = "dataset_stats.json"
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the parsing needs the file handle; the derived values are computed
# after the file has been closed.
with open(json_name) as file:
    stats = json.load(file)

# Statistics rounded to 3 decimals; they are passed to transforms.Normalize
# in prepare_image().
std = [round(i, 3) for i in stats["std"]]
mean = [round(i, 3) for i in stats["mean"]]
n = stats["n"]

# One numeric label per line. Trailing blank lines at the end of the file are
# tolerated; a blank or malformed line anywhere else still raises ValueError,
# so labels can never silently shift relative to their line numbers.
with open("train_label.txt", "r") as f:
    labels = [float(label) for label in f.read().rstrip().splitlines()]

def prepare_image(image):
    """Convert a PIL image into a normalised 3 x 256 x 256 tensor.

    Args:
        image: PIL image; any mode other than RGB is converted to RGB first.

    Returns:
        The image as a tensor after resizing, centre-cropping and per-channel
        normalisation with the module-level ``mean`` and ``std``.
    """
    # The pipeline below expects three channels.
    rgb_image = image if image.mode == 'RGB' else image.convert("RGB")

    pipeline = transforms.Compose([
        # Short side scaled to 256px (aspect ratio kept), then a centre crop
        # so that every image ends up 256 x 256.
        transforms.Resize(256),
        transforms.CenterCrop([256, 256]),
        # PIL image -> tensor.
        transforms.ToTensor(),
        # Per channel: subtract the mean, divide by the std.
        transforms.Normalize(mean, std),
    ])
    return pipeline(rgb_image)

def prepare_images(image_numbers, image_dir="train_images"):
    """Load the numbered images, preprocess each one and stack them.

    Args:
        image_numbers: iterable of image numbers; number ``k`` is read from
            ``<image_dir>/<k>.jpg``.
        image_dir: directory containing the images. Defaults to the
            ``train_images`` folder that was previously hard-coded.

    Returns:
        A single tensor made by stacking the result of ``prepare_image`` for
        every image, in the order given.

    Raises:
        RuntimeError: from ``torch.stack`` when ``image_numbers`` is empty.
    """
    prepared = []
    for number in image_numbers:
        path = os.path.join(image_dir, str(number) + ".jpg")
        # Image.open keeps the underlying file open until the image is
        # closed; the context manager releases the handle as soon as the
        # image has been turned into a tensor.
        with Image.open(path) as img:
            prepared.append(prepare_image(img))
    image_tensor = torch.stack(prepared)
    return image_tensor

# Augmentation pipeline: random resized crop to 224 x 224, then a random
# horizontal flip.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop([224, 224]),
    transforms.RandomHorizontalFlip(),
])
# Deterministic 224 x 224 crop taken from the centre.
centre_crop = transforms.CenterCrop([224, 224])

In [5]:
# Each image is identified by its position in `labels`: 0 .. len(labels) - 1.
images = list(range(len(labels)))

# Seeded, shuffled hold-out split of the image ids together with their labels.
train_images, val_images, train_labels, val_labels = train_test_split(
    images,
    labels,
    test_size=val_percent,
    shuffle=True,
    random_state=random_seed,
)

In [43]:
# Preprocess the first 100 validation images into one stacked tensor, then
# show the shape of a single preprocessed image.
images = prepare_images(val_images[:100])
images[0].shape

torch.Size([3, 256, 256])

In [44]:
# Apply the random crop / flip pipeline to the stacked batch.
# NOTE(review): this rebinds `images` to its own transformed version, so
# re-running the cell transforms the already-transformed batch again; the
# later cells read this same name.
images = train_transform(images)

In [28]:
class IIPAModel(nn.Module):
    """Backbone network whose final layer is replaced by a 1-output head.

    Args:
        model_name: ``"resnet"`` builds a torchvision ResNet-50 (constructed
            with default arguments); any other value is passed to
            ``EfficientNet.from_pretrained`` as the model name.
    """

    def __init__(self, model_name):
        super().__init__()
        if model_name == "resnet":
            self.model = torchvision.models.resnet50()
            # Replace the final fully connected layer with a single-output
            # one, reading the input width from the layer being replaced
            # rather than hard-coding it.
            head_in = self.model.fc.in_features
            self.model.fc = nn.Linear(in_features=head_in, out_features=1)
            nn.init.kaiming_uniform_(self.model.fc.weight)
        else:
            # NOTE(review): `EfficientNet` is not imported in the notebook's
            # visible import cell, so this branch raises NameError unless it
            # is imported elsewhere — confirm and add the import.
            self.model = EfficientNet.from_pretrained(model_name, num_classes = 1)
            nn.init.kaiming_uniform_(self.model._fc.weight)

    def forward(self, batch):
        """Run the wrapped network on ``batch`` and return its raw output.

        The head has one output unit, so the result is presumably shaped
        (batch, 1) — the notebook selects column 0 from it.
        """
        preds = self.model(batch)
        return preds


In [32]:
# Run the augmentation on the batch, keep only the first image and give it a
# leading dimension of size 1.
image = train_transform(images)[0].unsqueeze(0)

In [35]:
# Build the model using its "resnet" (ResNet-50) branch.
model = IIPAModel("resnet")

In [62]:
# Forward the batch and keep column 0 of the output: one prediction per image.
# NOTE(review): model.eval() is never called in the visible cells — confirm
# that running this forward pass in training mode is intended.
preds = model(images)[:,0]

In [61]:
# Labels of the same first 100 validation images, as a 1-D tensor.
targets = torch.tensor(val_labels[:100])

In [64]:
# Loss between the predictions and the labels for the first 100 validation
# images, using the loss_fn set in the hyperparameter cell.
loss = loss_fn(preds, targets)

In [65]:
# Display the loss value computed in the previous cell.
loss

tensor(29.6715, grad_fn=<MseLossBackward>)

In [66]:
# Display every prediction in the batch.
preds

tensor([13.0924,  3.1055,  2.2793,  2.8653,  3.6857,  3.6473,  2.5182,  2.3701,
         2.1547,  7.1868,  5.2493,  2.4435,  3.0774,  3.4904,  2.2854,  4.7625,
         4.2139,  2.7889,  2.1392,  2.2298,  2.9881,  2.5720,  2.2757,  2.4911,
         2.3010,  3.9357,  3.9951,  2.0147,  2.5054,  5.5637,  2.6945,  3.6643,
         2.4986,  2.2204,  3.1293,  2.4174,  2.5322,  2.4223,  2.3256,  2.8337,
         2.8229,  2.4854,  2.3356,  3.7550,  3.7580,  2.7149,  2.6535,  3.6883,
         2.2592,  2.8344,  2.7861,  2.8379,  3.0066,  2.1901,  2.2175,  3.6012,
         2.0512,  2.5451,  2.9383,  3.3911,  2.7839,  2.4553,  3.9779,  2.5236,
         3.1264,  2.9413,  6.0884,  2.2943,  2.5718,  3.1930,  2.1172,  5.1770,
         7.1296,  1.8702,  3.6414, 11.0113,  2.6712,  4.7296,  5.4220,  2.6092,
         2.7502,  3.3478,  2.4422,  2.5901,  2.5474,  3.5971,  2.7060,  2.4622,
         2.3436,  2.5489,  2.0326,  2.0908,  2.9739,  2.7572,  3.1074,  1.8305,
         3.0023,  2.3739,  2.1038,  2.30