In [99]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plot
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

In [None]:
# -o overwrites already-extracted files without prompting, so re-running this
# cell (Restart & Run All) never waits on input; -q keeps the cell output short.
!unzip -o -q Homework_Dataset.zip

In [None]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

class GlobalDataset(Dataset):
    """Dataset of (earlier image, later image, months between them) samples.

    Images are discovered under ``<base_dir>/<entry>/images``. The date of an
    image is taken from its file name: after splitting on ``'_'``, field 2 is
    read as the year and field 3 as the month. Every pair of images from the
    same ``images`` folder becomes one sample, ordered so that the earlier
    image comes first.

    Args:
        base_dir: Root folder containing one sub-folder per image series.
        transform: Optional callable applied to each PIL image on every
            access. When ``None`` the image is converted with
            ``transforms.ToTensor()``.
        device: Unused; accepted only for backward compatibility.
    """

    def __init__(self, base_dir, transform = None, device: torch.device = None):
        self.base_dir = base_dir
        self.transform = transform
        # images folder -> {image path -> "year_month" string}
        self.image_paths = {}
        # list of [start_path, end_path, months]
        self.data = []
        # image path -> decoded (untransformed) RGB PIL image
        self.image_cache = {}
        self.load()

    def __len__(self):
        """Return the number of (start, end, months) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(start_image, end_image, time_skip)`` for sample ``idx``.

        Both images are returned as tensors flattened to shape ``(1, N)``;
        ``time_skip`` is the month difference as a Python int.
        """
        start_path, end_path, time_skip = self.data[idx]
        start_image = self._load_tensor(start_path)
        end_image = self._load_tensor(end_path)
        return start_image, end_image, time_skip

    def _load_tensor(self, path):
        """Load ``path`` (decoding it at most once) and return a flat tensor."""
        image = self.image_cache.get(path)
        if image is None:
            # The context manager closes the file handle; convert() decodes the
            # pixels into an independent in-memory RGB image first.
            with Image.open(path) as handle:
                image = handle.convert("RGB")
            # Cache the decoded image, not the transformed tensor, so random
            # augmentations are re-sampled on every access instead of frozen.
            self.image_cache[path] = image

        if self.transform is not None:
            tensor = self.transform(image)
        else:
            tensor = transforms.ToTensor()(image)
        # Flatten to a single row; 3 x 128 x 128 inputs give (1, 49152).
        return tensor.reshape((1, -1))

    def compute_and_sort(self, date1, date2):
        """Compare two ``"year_month"`` strings.

        Returns:
            ``(months, first_bigger)`` where ``months`` is the non-negative
            month distance and ``first_bigger`` is True when ``date1`` is the
            same as or later than ``date2``.
        """
        date1_split = date1.split('_')
        year1, month1 = int(date1_split[0]), int(date1_split[1])
        date2_split = date2.split('_')
        year2, month2 = int(date2_split[0]), int(date2_split[1])

        if (year1, month1) >= (year2, month2):
            return (year1 - year2) * 12 + month1 - month2, True
        return (year2 - year1) * 12 + month2 - month1, False

    def load(self):
        """Scan ``base_dir`` and fill ``image_paths`` and ``data``."""
        # Sorted listings keep the sample order independent of the file system.
        for entry in sorted(os.listdir(self.base_dir)):
            current_dir = os.path.join(self.base_dir, entry, "images")
            if not os.path.isdir(current_dir):
                # Skip stray files and folders without an "images" sub-folder.
                continue
            for image_name in sorted(os.listdir(current_dir)):
                path = os.path.join(current_dir, image_name)
                words = image_name.split('_')
                year_month = words[2] + "_" + words[3]
                self.image_paths.setdefault(current_dir, {})[path] = year_month

        for dates_by_path in self.image_paths.values():
            for path1, date1 in dates_by_path.items():
                for path2, date2 in dates_by_path.items():
                    months, first_bigger = self.compute_and_sort(date1, date2)
                    if first_bigger and months > 0:
                        # The mirrored iteration (path2, path1) emits this exact
                        # sample; adding it here too would duplicate it and let
                        # identical samples land in different random splits.
                        continue
                    if first_bigger:
                        self.data.append([path2, path1, months])
                    else:
                        self.data.append([path1, path2, months])

class Model(nn.Module):
    """Fully connected network mapping a flattened image plus a month offset to a flattened image.

    The input passes through two ReLU hidden layers; the month offset is then
    appended as one extra feature before the sigmoid output layer, which
    produces a tensor with the same last dimension as the input.

    Args:
        input_size: Number of features of the flattened image (default 49152).
        hidden1_size: Width of the first hidden layer (default 128).
        hidden2_size: Width of the second hidden layer (default 64).
        device: Device the layers are created on. Defaults to CUDA when
            available, otherwise CPU.
    """

    def __init__(self, input_size = 49152, hidden1_size = 128, hidden2_size = 64, device = None):
        super().__init__()
        if device is None:
            # Resolve the device here instead of reading a module-level global.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hidden1 = nn.Linear(input_size, hidden1_size).to(device)
        self.act1 = nn.ReLU()
        self.hidden2 = nn.Linear(hidden1_size, hidden2_size).to(device)
        self.act2 = nn.ReLU()
        # +1 input feature: the month offset concatenated in forward().
        self.output = nn.Linear(hidden2_size + 1, input_size).to(device)
        self.act_output = nn.Sigmoid()

    def forward(self, x, months):
        """Predict the later image.

        Args:
            x: Tensor of shape ``(batch, 1, input_size)`` on the model's device.
            months: Tensor of shape ``(batch,)`` with the month offsets; any
                numeric dtype is accepted.

        Returns:
            Tensor of shape ``(batch, 1, input_size)`` with values in [0, 1].
        """
        x = self.act1(self.hidden1(x))
        x = self.act2(self.hidden2(x))
        # Match x's dtype/device explicitly (months usually arrives as int64)
        # and reshape (batch,) -> (batch, 1, 1) so it can be concatenated.
        months = months.to(device = x.device, dtype = x.dtype)[:, None, None]
        x = torch.cat((x, months), dim = 2)
        return self.act_output(self.output(x))

# Seed PyTorch's global RNG so the random split, loader shuffling, weight
# initialisation and random augmentations are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

device = get_default_device()

# NOTE(review): these random augmentations are applied to every split,
# including validation and test — confirm that this is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation((-10, 10)),
    transforms.RandomGrayscale(0.15),
    transforms.RandomVerticalFlip(0.2),
])
# NOTE(review): assumes the archive extracts to a folder named "Homework Dataset" — confirm.
global_dataset = GlobalDataset(base_dir = "Homework Dataset", transform = transform)

# 70 % / 15 % / 15 % split. The subsets get *_set names so they are not
# shadowed by the train() function defined further down.
train_set, validation_set, test_set = torch.utils.data.random_split(global_dataset, [0.7, 0.15, 0.15])

train_loader = DataLoader(train_set, batch_size = 32, shuffle = True)
validation_loader = DataLoader(validation_set, batch_size = 32, shuffle = False)
test_loader = DataLoader(test_set, batch_size = 32, shuffle = False)

model = Model()
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Per-epoch histories appended to by train() and val().
train_accuracies = []
train_losses = []
validation_accuracies = []
test_accuracies = []

def train():
    """Run one training epoch over ``train_loader``.

    Appends the summed batch loss to ``train_losses`` and the fraction of
    output elements exactly equal to their target to ``train_accuracies``.
    """
    total = 0
    ok = 0
    train_loss = 0
    model.train()
    for start_images, end_images, months in train_loader:
        # as_tensor moves the batch to the device without the extra copy (and
        # warning) that torch.tensor() causes on an existing tensor.
        data = torch.as_tensor(start_images, device = device)
        y = torch.as_tensor(end_images, device = device)
        y_pred = model(data, torch.as_tensor(months, device = device))
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Count compared elements so ok / total is a fraction in [0, 1];
        # len(mini_batch) is the number of tuple fields (3), not a sample count.
        total += y.numel()
        # NOTE(review): exact float equality is a very strict match criterion
        # for a regression output — confirm this is the intended metric.
        ok += (y == y_pred).sum().item()
        train_loss += loss.item()
    train_losses.append(train_loss)
    # Guard against an empty loader.
    train_accuracies.append(ok / total if total else 0.0)

def val(test: bool):
    """Evaluate the model on the validation or test split.

    Args:
        test: When true, evaluate on ``test_loader`` and append the result to
            ``test_accuracies``; otherwise use ``validation_loader`` and
            ``validation_accuracies``.

    The recorded value is the fraction of output elements exactly equal to
    their target.
    """
    total = 0
    ok = 0
    model.eval()
    loader = test_loader if test else validation_loader
    # No gradients are needed for evaluation; this avoids building the graph.
    with torch.no_grad():
        for start_images, end_images, months in loader:
            data = torch.as_tensor(start_images, device = device)
            y = torch.as_tensor(end_images, device = device)
            y_pred = model(data, torch.as_tensor(months, device = device))
            # Count compared elements (not batches) so ok / total is a
            # fraction in [0, 1].
            total += y.numel()
            ok += (y == y_pred).sum().item()
    # Guard against an empty loader.
    accuracy = ok / total if total else 0.0
    if test:
        test_accuracies.append(accuracy)
    else:
        validation_accuracies.append(accuracy)

def run(epoch_number: int):
    """Train for ``epoch_number`` epochs, validating after each, then test once.

    Progress is shown on a tqdm bar; the final test accuracy is printed.
    """
    epochs = tqdm(range(epoch_number))
    for _ in epochs:
        train()
        val(False)
        # Read the latest entries: the history lists are module-level and may
        # already hold values from an earlier run() call in the same kernel.
        epochs.set_postfix_str(f"Train Loss: {train_losses[-1]}, Train Accuracy: {train_accuracies[-1]}, Validate Accuracy: {validation_accuracies[-1]}")
    val(True)
    print(f"Test Accuracy: {test_accuracies[-1]}")

run(epoch_number = 50)

# Save each recorded history as its own PNG, clearing the figure in between.
for history, file_name in (
    (train_accuracies, "train_accuracies.png"),
    (train_losses, "train_losses.png"),
    (validation_accuracies, "validation_accuracies.png"),
):
    plot.plot(history)
    plot.savefig(file_name)
    plot.clf()
