<a target="_blank" href="https://colab.research.google.com/github/Tedyst/FII-AdvRN/blob/master/lab4.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [68]:
import torch
import os
import random
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch import GradScaler
from torch import nn, optim
import tqdm
import torchvision.transforms.functional
import torchvision.transforms.v2 as T

# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [69]:
class SpaceNetDataset(Dataset):
    """Pairs of images of the same location taken at two different dates.

    Every file found (recursively) under ``folder`` is opened as an image and
    kept in memory as a tensor. Images are grouped by a location id parsed
    from the file name, and every pair ``(earlier, later)`` of images of the
    same location whose dates strictly increase becomes one sample. The pair
    list is shuffled once with Python's ``random`` module.

    File names are parsed as ``_``-separated fields with the year and month in
    fields 2 and 3 and the location id in the second-to-last field.
    NOTE(review): this naming scheme is inferred from the parsing code below;
    confirm it against the dataset.
    """

    def __init__(self, folder):
        # location id -> list of (date as year * 12 + month, image tensor)
        self.files = {}
        # ((date_a, image_a), (date_b, image_b)) tuples with date_a < date_b
        self.combinations = []

        for root, _, files in os.walk(folder):
            for file in files:
                parts = file.split('_')
                year, month = parts[2:4]
                map_id = parts[-2]
                # The context manager releases the file handle as soon as the
                # pixels have been copied into a tensor.
                with Image.open(os.path.join(root, file)) as image:
                    image_tensor = torchvision.transforms.functional.pil_to_tensor(image)
                self.files.setdefault(map_id, []).append(
                    (int(year) * 12 + int(month), image_tensor)
                )

        for place in self.files.values():
            for a in place:
                for b in place:
                    if a[0] < b[0]:
                        self.combinations.append((a, b))

        random.shuffle(self.combinations)

        # Kept for compatibility; equal to len(self).
        self.count = len(self.combinations)

    def __len__(self):
        return len(self.combinations)

    def __getitem__(self, idx):
        """Return ``(earlier_image, later_image, months_between)`` for ``idx``.

        Both images get the same random rotation by a multiple of 90 degrees.
        They are then converted to float32. For integer-typed images,
        ``convert_image_dtype`` rescales into ``[0, 1]``.
        """
        (a_date, a_image), (b_date, b_image) = self.combinations[idx]

        # One draw shared by both images keeps the pair spatially aligned.
        angle = 90 * int(torch.randint(0, 4, (1,)).item())

        a_image = torchvision.transforms.functional.rotate(a_image, angle)
        b_image = torchvision.transforms.functional.rotate(b_image, angle)

        # torchvision treats float images as living in [0, 1], and ops such as
        # ColorJitter clamp float outputs to that range. Leaving pixels at
        # 0..255 would saturate augmented pairs, so scale instead of casting.
        a_image = torchvision.transforms.functional.convert_image_dtype(a_image, torch.float32)
        b_image = torchvision.transforms.functional.convert_image_dtype(b_image, torch.float32)

        return (a_image, b_image, b_date - a_date)

# Root folder that SpaceNetDataset walks recursively for images.
DATASET_ROOT = "./Advanced-Topics-in-Neural-Networks-Template-2024/Lab04/Dataset/"
dataset = SpaceNetDataset(DATASET_ROOT)

In [70]:
class SplitDataset(SpaceNetDataset):
    """A contiguous fractional slice of an already-loaded ``SpaceNetDataset``.

    ``SpaceNetDataset.__init__`` is deliberately not called. The source
    dataset's (already shuffled) ``combinations`` are sliced, so no files are
    re-read. ``__len__`` and ``__getitem__`` are inherited and operate on the
    slice.

    Args:
        dataset: Source dataset whose ``combinations`` are sliced.
        split_start: Start of the slice, as a fraction in ``[0, 1]``.
        split_end: End of the slice, as a fraction in ``[split_start, 1]``.

    Raises:
        ValueError: If the fractions are not ordered within ``[0, 1]``.
    """

    def __init__(self, dataset: SpaceNetDataset, split_start, split_end):
        if not 0 <= split_start <= split_end <= 1:
            raise ValueError(
                f"expected 0 <= split_start <= split_end <= 1, got {split_start}, {split_end}"
            )
        total = len(dataset.combinations)
        # int() truncation on both ends means adjacent splits sharing a
        # boundary fraction tile the pair list exactly, without overlap.
        start, end = int(total * split_start), int(total * split_end)
        self.combinations = dataset.combinations[start:end]
        # Mirror the attributes SpaceNetDataset.__init__ sets, so a split
        # exposes the same interface as its base class.
        self.files = dataset.files  # shared with the source dataset, not copied
        self.count = len(self.combinations)

In [71]:
class TransformationDataset(Dataset):
    """Apply one random transformation identically to both images of a pair.

    The wrapped dataset must yield ``(image, image, delta)`` triples. torch's
    CPU RNG state is snapshotted before the first image is transformed and
    restored before the second, so randomness drawn from torch's default
    generator is replayed identically. ``delta`` is passed through untouched.
    """

    def __init__(self, dataset: SpaceNetDataset, transformation):
        self.dataset = dataset  # source of (image, image, delta) triples
        self.transformation = transformation  # callable applied to each image

    def __len__(self):
        size = len(self.dataset)
        return size

    def __getitem__(self, idx):
        source, target, delta = self.dataset[idx]

        rng_snapshot = torch.get_rng_state()
        source = self.transformation(source)

        # Rewind the RNG so the second call draws the same random parameters.
        torch.set_rng_state(rng_snapshot)
        target = self.transformation(target)

        return source, target, delta

In [72]:
# Pinned host memory only speeds up host-to-GPU copies, so request it only
# when training on CUDA.
pin_memory = device.type == "cuda"
# Mixed precision (autocast plus gradient scaling) is used only on CUDA.
enable_half = device.type == "cuda"
# GradScaler's `device` argument is documented as a device-type string
# (e.g. "cuda"), not a torch.device object.
scaler = GradScaler(device.type, enabled=enable_half)

In [73]:
# Training-time augmentations. TransformationDataset replays torch's RNG so
# both images of a pair receive the same random parameters.
augmentation_steps = [
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(180),
    T.RandomAffine(0, translate=(0.1, 0.1)),
    T.RandomAffine(0, scale=(0.9, 1.1)),
    T.RandomAffine(0, shear=15),
    T.ColorJitter(0.2, 0.2, 0.2, 0.2),
    T.RandomErasing(),
]
transformations = T.Compose(augmentation_steps)

# Contiguous 70% / 15% / 15% slices of the (already shuffled) pair list.
# NOTE(review): the split is over pairs, not locations, so the same image can
# appear in several splits; consider splitting by location id instead.
train_dataset = TransformationDataset(SplitDataset(dataset, 0, 0.7), transformations)
validate_dataset = SplitDataset(dataset, 0.7, 0.85)
test_dataset = SplitDataset(dataset, 0.85, 1)

loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=pin_memory)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
validate_dataloader = DataLoader(validate_dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

In [74]:
class SimpleModel(nn.Module):
    """Convolutional encoder-decoder that predicts an image after a time skip.

    The input image is encoded into a ``latent_dim`` vector. That vector is
    concatenated with a learned linear embedding of the time skip and decoded
    back into an image of the same spatial size.

    Args:
        image_size: Height and width of the square input. Must be divisible
            by 16, because the encoder halves the resolution four times and the
            decoder doubles it four times.
        latent_dim: Size of the image latent and of the time embedding.
        time_dim: Number of features describing one time skip.
        channels: Number of image channels in both input and output.

    Raises:
        ValueError: If ``image_size`` is not a multiple of 16. Such sizes
            would otherwise fail later, with a shape mismatch at forward time.
    """

    def __init__(self, image_size=128, latent_dim=256, time_dim=1, channels=3) -> None:
        super().__init__()
        if image_size % 16 != 0:
            raise ValueError(f"image_size must be divisible by 16, got {image_size}")
        bottleneck = image_size // 16  # spatial size after four stride-2 convs

        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=3, stride=2, padding=1),  # size / 2
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # size / 4
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # size / 8
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # size / 16
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * bottleneck ** 2, latent_dim),
        )

        self.time_embedding = nn.Linear(time_dim, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim * 2, 128 * bottleneck ** 2),
            nn.ReLU(),
            nn.Unflatten(1, (128, bottleneck, bottleneck)),
            # kernel 4 / stride 2 / padding 1 exactly doubles the spatial size.
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, channels, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x, time_skip):
        """Predict the image ``time_skip`` steps after ``x``.

        ``x`` is ``(batch, channels, image_size, image_size)``. ``time_skip`` is
        either a flat ``(batch,)`` vector of scalar skips or a
        ``(batch, time_dim)`` matrix.
        """
        x_latent = self.encoder(x)
        if time_skip.dim() == 1:
            # A scalar skip per sample becomes a single feature column.
            time_skip = time_skip.unsqueeze(1)
        time_embedding = self.time_embedding(time_skip)
        combined_latent = torch.cat((x_latent, time_embedding), dim=1)
        return self.decoder(combined_latent)

In [75]:
model = SimpleModel().to(device)
model = torch.jit.script(model)
# The model regresses the pixel values of the later image, so use a
# regression loss. CrossEntropyLoss would treat the colour channels (dim 1)
# as class logits and the pixel intensities as class probabilities, which is
# meaningless for image prediction.
criterion = nn.MSELoss()
# Older PyTorch releases implement fused Adam only for CUDA tensors, so
# request it only on CUDA. The CPU path uses the default implementation.
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0, fused=device.type == "cuda")

In [76]:
def train():
    """Run one training epoch over ``train_dataloader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``scaler``. Autocast is enabled when ``enable_half`` is set.

    Returns:
        Mean loss per batch, or ``nan`` if the loader yielded no batches.
    """
    model.train()

    # Accumulate on the device and read the total back once at the end.
    # Calling loss.item() every step would force a host/device sync.
    total_loss = torch.zeros((), device=device)
    num_batches = 0

    for a_image, b_image, time_skip in train_dataloader:
        a = a_image.to(device, non_blocking=True)
        b = b_image.to(device, non_blocking=True)
        time_skip = time_skip.to(device, non_blocking=True).to(torch.float32)

        # Clear gradients before the forward pass, so anything left over
        # (e.g. from an interrupted epoch) is not applied in this step.
        optimizer.zero_grad()
        with torch.autocast(device.type, enabled=enable_half):
            outputs = model(a, time_skip)
            loss = criterion(outputs, b)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.detach().float()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss.item() / num_batches

train()

KeyboardInterrupt: 

In [54]:
def val():
    """Return the mean per-batch loss of ``model`` on ``validate_dataloader``.

    The model is put in eval mode, and the loop runs under ``torch.no_grad()``
    so no autograd graph is recorded. Returns ``nan`` if the loader yields no
    batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # Evaluation needs no gradients. Without no_grad() every forward pass
    # would record an autograd graph and keep its activations alive.
    with torch.no_grad():
        for a_image, b_image, time_skip in validate_dataloader:
            a = a_image.to(device, non_blocking=True)
            b = b_image.to(device, non_blocking=True)
            time_skip = time_skip.to(device, non_blocking=True).to(torch.float32)
            with torch.autocast(device.type, enabled=enable_half):
                outputs = model(a, time_skip)
                loss = criterion(outputs, b)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

val()

KeyboardInterrupt: 

In [None]:
def test():
    """Return the mean per-batch loss of ``model`` on ``test_dataloader``.

    The model is put in eval mode, and the loop runs under ``torch.no_grad()``
    so no autograd graph is recorded. Returns ``nan`` if the loader yields no
    batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # Evaluation needs no gradients. Without no_grad() every forward pass
    # would record an autograd graph and keep its activations alive.
    with torch.no_grad():
        for a_image, b_image, time_skip in test_dataloader:
            a = a_image.to(device, non_blocking=True)
            b = b_image.to(device, non_blocking=True)
            time_skip = time_skip.to(device, non_blocking=True).to(torch.float32)
            with torch.autocast(device.type, enabled=enable_half):
                outputs = model(a, time_skip)
                loss = criterion(outputs, b)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

test()

In [None]:
def run(n):
    """Train for ``n`` epochs, reporting train/val/test loss after each one.

    Returns:
        A list with one ``(train_loss, val_loss, test_loss)`` tuple per epoch.
    """
    history = []
    for _ in tqdm.trange(n):
        train_loss = train()
        # NOTE(review): evaluating on the test split every epoch invites
        # test-set-driven tuning; consider reporting it only at the end.
        test_loss = test()
        val_loss = val()
        history.append((train_loss, val_loss, test_loss))
        # tqdm.write prints above the progress bar instead of garbling it.
        tqdm.tqdm.write(f"Train loss: {train_loss}, Val loss: {val_loss}, Test loss: {test_loss}")
    return history