In [None]:
import os
import re
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np

class HomeworkDataset(Dataset):
    """Dataset of chronologically consecutive image pairs, grouped per location.

    Expected layout::

        <folder_path>/<location>/images/<file name containing YYYY_MM>

    Every sample is a ``(start_img, end_img, time_skip)`` tuple where
    ``time_skip`` is the number of months between the dates parsed from the
    two file names.

    Args:
        folder_path: Root directory holding one sub-directory per location.
        transform: Optional callable applied to both images of a pair. When
            given, a shared random rotation is applied first and the
            transform is then run with identical random decisions for both
            images, so the input and its target stay spatially aligned.
        debug: When true, print shapes and plot every loaded pair. Off by
            default because printing/plotting inside ``__getitem__`` floods
            the output and blocks iteration over a ``DataLoader``.
    """

    # Maximum absolute rotation, in degrees, of the paired augmentation.
    MAX_ROTATION_DEGREES = 30.0

    # Matches the "YYYY_MM" date stamp embedded in image file names.
    _DATE_PATTERN = re.compile(r"(\d{4})_(\d{2})")

    def __init__(self, folder_path, transform=None, debug=False):
        self.folder_path = folder_path
        self.transform = transform
        self.debug = debug
        # Decoded images keyed by path. Keying by path (rather than by sample
        # index) lets pair i's end image be reused as pair i+1's start image.
        self._image_cache = {}
        self.image_pairs = self.load_image_pairs()

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return ``(start_img, end_img, time_skip)`` for the pair at ``idx``."""
        start_img_path, end_img_path, time_skip = self.image_pairs[idx]

        start_img = self._load_rgb(start_img_path)
        end_img = self._load_rgb(end_img_path)

        if self.debug:
            print(f"Start image shape after reading (H, W, C): {np.array(start_img).shape}")
            print(f"End image shape after reading (H, W, C): {np.array(end_img).shape}")

        if self.transform is not None:
            # Sample ONE angle and rotate both images by it; calling a
            # RandomRotation instance twice would rotate them differently.
            limit = self.MAX_ROTATION_DEGREES
            angle = transforms.RandomRotation.get_params([-limit, limit])
            start_img = transforms.functional.rotate(start_img, angle)
            end_img = transforms.functional.rotate(end_img, angle)

            start_img, end_img = self._apply_paired(start_img, end_img)

            if self.debug:
                print(f"Transformed start image shape: {start_img.shape}")
                print(f"Transformed end image shape: {end_img.shape}")
                self._show_pair(start_img, end_img)

        if self.debug:
            print(f"Loading image pair {idx+1}/{len(self)}: Time skip: {time_skip} months")

        return start_img, end_img, time_skip

    def _load_rgb(self, path):
        """Return the RGB image stored at ``path``, decoding it at most once."""
        cached = self._image_cache.get(path)
        if cached is None:
            # ``convert`` yields an independent in-memory image, so the file
            # handle can be released as soon as the conversion is done.
            with Image.open(path) as raw:
                cached = raw.convert('RGB')
            self._image_cache[path] = cached
        return cached

    def _apply_paired(self, start_img, end_img):
        """Apply ``self.transform`` to both images with the same random draws.

        The torch RNG state is rewound between the two calls. This assumes the
        transform takes its randomness from torch's global CPU generator, as
        torchvision's built-in random transforms do.
        """
        rng_state = torch.get_rng_state()
        start_out = self.transform(start_img)
        torch.set_rng_state(rng_state)
        end_out = self.transform(end_img)
        return start_out, end_out

    @staticmethod
    def _show_pair(start_img, end_img):
        """Plot a transformed pair side by side (debug aid only)."""
        # convert tensors back to PIL images for proper visualization
        plt.subplot(1, 2, 1)
        plt.imshow(transforms.ToPILImage()(start_img))
        plt.title("Start Image")
        plt.subplot(1, 2, 2)
        plt.imshow(transforms.ToPILImage()(end_img))
        plt.title("End Image")
        plt.show()

    def extract_location_from_path(self, folder_path):
        """Return the path segment(s) between 'Homework Dataset' and 'images'.

        Raises:
            ValueError: If 'Homework Dataset' is not a component of the path.
        """
        # Normalise the platform separator so the split also works on Windows.
        parts = folder_path.replace(os.sep, '/').split('/')
        dataset_index = parts.index('Homework Dataset')
        if 'images' in parts:
            images_index = parts.index('images')
            return '/'.join(parts[dataset_index + 1 : images_index])
        return '/'.join(parts[dataset_index + 1 :])

    def extract_year_and_month(self, image_name):
        """Parse ``(year, month)`` from the file-name part of ``image_name``.

        Returns ``(0, 0)`` when the name carries no ``YYYY_MM`` stamp.
        """
        match = self._DATE_PATTERN.search(os.path.basename(image_name))
        if match is None:
            return 0, 0
        year, month = map(int, match.groups())
        return year, month

    def calculate_time_skip(self, start_image_name, end_image_name):
        """Return the absolute number of months between the two image dates."""
        start_year, start_month = self.extract_year_and_month(start_image_name)
        end_year, end_month = self.extract_year_and_month(end_image_name)
        time_skip = (end_year - start_year) * 12 + (end_month - start_month)
        return abs(time_skip)

    def load_image_pairs(self):
        """Build ``(start_path, end_path, time_skip)`` tuples for each location.

        Files inside each ``images`` directory are ordered by their parsed
        date (file name as tie-breaker) and paired with their successor, so
        the start image is never newer than the end image. Entries of the
        root folder that do not contain an ``images`` directory are skipped.
        Files without a date stamp sort first, because they parse as (0, 0).
        """
        image_pairs = []

        # ``os.listdir`` returns entries in arbitrary order; sort everything
        # so the dataset is deterministic across runs and machines.
        for dir_name in sorted(os.listdir(self.folder_path)):
            images_dir = os.path.join(self.folder_path, dir_name, "images")
            if not os.path.isdir(images_dir):
                continue

            image_files = sorted(
                (
                    name
                    for name in os.listdir(images_dir)
                    if os.path.isfile(os.path.join(images_dir, name))
                ),
                key=lambda name: (self.extract_year_and_month(name), name),
            )

            for start_name, end_name in zip(image_files, image_files[1:]):
                start_image = os.path.join(images_dir, start_name)
                end_image = os.path.join(images_dir, end_name)
                time_skip = self.calculate_time_skip(start_image, end_image)
                image_pairs.append((start_image, end_image, time_skip))

        return image_pairs

In [None]:
# Per-image preprocessing: tensor conversion, random flips, then per-channel
# normalisation with mean 0.5 and std 0.5.
# NOTE(review): the same transform object (including the random flips) also
# backs the validation and test subsets, since all three are split from one
# dataset below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

dataset = HomeworkDataset("/content/drive/MyDrive/Homework Dataset", transform=transform)

# 70 % of the samples go to training; the rest is halved between validation
# and test, with the test split receiving the extra sample on an odd remainder.
train_size = int(0.7 * len(dataset))
val_size = (len(dataset) - train_size) // 2
test_size = len(dataset) - (train_size + val_size)

train_data, val_data, test_data = random_split(
    dataset, lengths=[train_size, val_size, test_size]
)

# Only the training loader reshuffles its samples each epoch.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=32, shuffle=reshuffle)
    for subset, reshuffle in (
        (train_data, True),
        (val_data, False),
        (test_data, False),
    )
)

In [None]:
class CustomModel(nn.Module):
    """Two-layer fully connected network producing an image-shaped output.

    The input is flattened per sample, passed through ``fc1`` -> ReLU ->
    ``fc2`` and reshaped to ``(batch, *output_shape)``.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the hidden layer.
        output_size: Number of features produced by the last layer.
        output_shape: Per-sample shape the output is reshaped to. Defaults to
            ``(3, 128, 128)``, the shape that used to be hard-coded.

    Raises:
        ValueError: If ``output_shape`` does not hold exactly ``output_size``
            elements (the reshape in ``forward`` could never succeed).
    """

    def __init__(self, input_size, hidden_size, output_size, output_shape=(3, 128, 128)):
        super().__init__()

        self.output_shape = tuple(output_shape)
        expected_size = torch.Size(self.output_shape).numel()
        if expected_size != output_size:
            raise ValueError(
                f"output_shape {self.output_shape} holds {expected_size} elements, "
                f"but output_size is {output_size}"
            )

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # ``flatten`` (unlike ``view``) also accepts non-contiguous inputs.
        x = x.flatten(start_dim=1)
        # Without a non-linearity the two linear layers would collapse into a
        # single affine map; ReLU adds no parameters, so state dicts keep the
        # same keys.
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        # reshape the output to match the target images
        return x.view(x.size(0), *self.output_shape)


In [None]:
# Positional arguments are (input_size, hidden_size, output_size).
# 49152 == 3 * 128 * 128, i.e. one flattened 3-channel 128x128 image.
model = CustomModel(3 * 128 * 128, 1024, 3 * 128 * 128)

criterion = nn.MSELoss()  # Suitable for image generation
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Device placement is currently disabled; everything runs on the default device.
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)


def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``criterion``.

    Returns:
        The mean of the per-batch losses, or ``nan`` when the loader yields no
        batches (instead of failing with ``ZeroDivisionError``).

    Raises:
        ValueError: If the model output shape differs from the target shape.
    """
    model.train()
    # Follow the model's device so enabling ``model.to(device)`` keeps working.
    device = next(model.parameters()).device
    total_train_loss = 0.0
    num_batches = 0
    for start_img, end_img, _ in train_loader:
        start_img, end_img = start_img.to(device), end_img.to(device)
        optimizer.zero_grad()
        output = model(start_img)
        # Explicit check rather than ``assert``, which is stripped under -O.
        if output.shape != end_img.shape:
            raise ValueError(f"Output shape: {output.shape}, Expected shape: {end_img.shape}")
        loss = criterion(output, end_img)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        return float("nan")
    return total_train_loss / num_batches

def val():
    """Evaluate the model on ``val_loader`` without tracking gradients.

    Uses the module-level ``model`` and ``criterion``.

    Returns:
        The mean of the per-batch losses, or ``nan`` when the loader yields no
        batches (e.g. a validation split of size zero) instead of failing with
        ``ZeroDivisionError``.

    Raises:
        ValueError: If the model output shape differs from the target shape.
    """
    model.eval()
    # Follow the model's device so enabling ``model.to(device)`` keeps working.
    device = next(model.parameters()).device
    total_val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for start_img, end_img, _ in val_loader:
            start_img, end_img = start_img.to(device), end_img.to(device)
            output = model(start_img)
            # Explicit check rather than ``assert``, which is stripped under -O.
            if output.shape != end_img.shape:
                raise ValueError(f"Output shape: {output.shape}, Expected shape: {end_img.shape}")
            total_val_loss += criterion(output, end_img).item()
            num_batches += 1
    if num_batches == 0:
        return float("nan")
    return total_val_loss / num_batches

def run(epochs):
    """Train and validate for ``epochs`` epochs, then plot both loss curves.

    Each epoch calls ``train()`` followed by ``val()`` and prints the two
    losses. After the last epoch the per-epoch losses are plotted against the
    zero-based epoch index.
    """
    history = {"Training Loss": [], "Validation Loss": []}

    for epoch_number in range(1, epochs + 1):
        epoch_train_loss = train()
        epoch_val_loss = val()

        history["Training Loss"].append(epoch_train_loss)
        history["Validation Loss"].append(epoch_val_loss)

        print(f"Epoch {epoch_number}/{epochs} - Training Loss: {epoch_train_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}")

    epoch_axis = range(epochs)
    for curve_label, curve_values in history.items():
        plt.plot(epoch_axis, curve_values, label=curve_label)
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()