In [1]:
#@title 256 dims
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tensorboard import notebook
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F

# Set the device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the hyperparameters
batch_size = 32
num_epochs = 100
learning_rate = 0.0005
weight_decay = 0.05  # Increased weight decay
save_interval = 1  # Save a checkpoint every `save_interval` epochs
accumulation_steps = 2  # Gradient accumulation steps
max_grad_norm = 5.0  # Maximum gradient norm for gradient clipping
save_dir = '/content/drive/MyDrive/AutoEncoder-Saves'  # Directory to save checkpoints

# Create the save directory if it does not exist.
# `exist_ok=True` replaces the separate exists() check, so there is no window
# between checking and creating, and re-running the cell is harmless.
os.makedirs(save_dir, exist_ok=True)

# Data transformations: resize to 256x256, convert to a tensor, then map each
# channel with (x - 0.5) / 0.5.
# NOTE: a later cell rebinds `transform`; whichever cell ran last wins.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


In [23]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch

# Preprocessing applied to every frame before it reaches the autoencoder.
transform = transforms.Compose([
    transforms.Resize((64, 64)),    # Resize images to 64x64
    transforms.ToTensor(),            # Convert images to PyTorch tensors
    # Normalize with (x - 0.5) / 0.5 so targets stay inside the [-1, 1] range
    # that the decoder's final Tanh can produce, and so the `(x + 1) / 2`
    # de-normalisation used when logging images is the exact inverse.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
class marioDataset(Dataset):
    """Dataset of fixed-length clips built from the image files in one directory.

    The files are sorted by name and cut into consecutive, non-overlapping
    groups of ``num_frames``; item ``idx`` is the ``idx``-th group. Trailing
    files that do not fill a whole group are ignored.

    Args:
        path_dir: Directory containing the frame images.
        num_frames: Number of consecutive frames per item (must be >= 1).
        transform: Callable applied to each RGB PIL image. It has to return a
            tensor so the frames can be stacked. When ``None``,
            ``transforms.ToTensor()`` is used.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """

    def __init__(self, path_dir, num_frames=4, transform=None):
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        self.dir = path_dir
        self.num_frames = num_frames
        # A PIL image cannot be stacked, so fall back to a plain tensor conversion.
        self.transform = transform if transform is not None else transforms.ToTensor()
        # os.listdir returns entries in arbitrary order; sort so that the frames
        # of one clip are adjacent and indexing is reproducible between runs.
        # NOTE: the sort is lexicographic, so frame numbers should be zero-padded.
        self.image_files = sorted(
            f for f in os.listdir(path_dir)
            if os.path.isfile(os.path.join(path_dir, f))
        )

    def __len__(self):
        """Return the number of complete clips."""
        return len(self.image_files) // self.num_frames

    def _frame_paths(self, idx):
        """Return the file paths of the frames that make up clip ``idx``."""
        num_clips = len(self)
        if idx < 0:
            idx += num_clips
        if not 0 <= idx < num_clips:
            raise IndexError(f"clip index out of range for dataset of length {num_clips}")
        # Clips are disjoint: clip k starts at file k * num_frames.
        start = idx * self.num_frames
        return [
            os.path.join(self.dir, name)
            for name in self.image_files[start:start + self.num_frames]
        ]

    def __getitem__(self, idx):
        """Load clip ``idx`` and return its transformed frames stacked on a new first axis.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        images = []
        for img_path in self._frame_paths(idx):
            # Close the file handle as soon as the pixels are decoded;
            # convert() returns an independent in-memory image.
            with Image.open(img_path) as raw:
                image = raw.convert("RGB")
            images.append(self.transform(image))
        return torch.stack(images)
    
# Single-frame dataset over the Mario screenshots, plus a shuffled loader.
dataset = marioDataset(
    path_dir='data/mario/',
    num_frames=1,
    transform=transform,
)
data_loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
# Display the shape of the first item as a quick sanity check.
dataset[0].shape

torch.Size([1, 3, 64, 64])

In [15]:
# Define the autoencoder model with residual connections
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm stages with a shortcut path.

    The shortcut is an empty ``nn.Sequential`` (a pass-through) when the input
    and output shapes already match; otherwise it is a strided 1x1 convolution
    followed by batch-norm. LeakyReLU(0.2) is applied after the first stage and
    again after the shortcut has been added.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: Stride of the first convolution and of the shortcut projection.
        momentum: Momentum passed to every ``BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, stride=1, momentum=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        shapes_match = stride == 1 and in_channels == out_channels
        if shapes_match:
            self.skip_connection = nn.Sequential()
        else:
            # Project the input so it can be added to the main branch.
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels, momentum=momentum),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = self.skip_connection(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        branch += shortcut
        return self.relu(branch)

class Autoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder halves the spatial size twice (two stride-2 residual blocks),
    the 128-channel feature map is flattened and projected to ``latent_dim``,
    and the decoder mirrors this with two nearest-neighbour upsamplings. The
    final ``Tanh`` limits the reconstruction to [-1, 1].

    Args:
        latent_dim: Size of the bottleneck vector.
        image_size: Height and width of the square input images. Must be a
            positive multiple of 4 so that the two down-samplings and the two
            up-samplings round-trip to the same size. The default of 64 gives
            the same layer shapes as the previously hard-coded 128 * 16 * 16.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim=64, image_size=64):
        super().__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(f"image_size must be a positive multiple of 4, got {image_size}")
        feature_size = image_size // 4  # spatial size after the two stride-2 blocks
        flat_features = 128 * feature_size * feature_size

        self.encoder = nn.Sequential(
            ResidualBlock(3, 64, stride=2),
            ResidualBlock(64, 128, stride=2)
        )
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, latent_dim)
        self.fc2 = nn.Linear(latent_dim, flat_features)
        self.unflatten = nn.Unflatten(1, (128, feature_size, feature_size))
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            ResidualBlock(128, 64, stride=1),
            nn.Upsample(scale_factor=2, mode='nearest'),
            ResidualBlock(64, 64, stride=1),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode and decode ``x``.

        Returns:
            A ``(reconstruction, latent)`` tuple, where ``latent`` is the output
            of ``fc1`` with shape ``(batch, latent_dim)``.
        """
        enc = self.encoder(x)
        latent_space = self.flatten(enc)
        latent_space = self.fc1(latent_space)
        x = self.fc2(latent_space)
        x = self.unflatten(x)
        x = self.decoder(x)
        return x, latent_space

# Perceptual loss using a pre-trained VGG network
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps of two image batches.

    Only the first four modules of ``vgg16().features`` are kept; they are put
    in eval mode and their parameters are excluded from gradient updates.
    """

    def __init__(self):
        super().__init__()
        vgg_features = torchvision.models.vgg16(pretrained=True).features
        first_layers = list(vgg_features.children())[:4]
        self.slice1 = nn.Sequential(*first_layers).eval()
        # Freeze the feature extractor; gradients still flow through it to its inputs.
        self.slice1.requires_grad_(False)

    def forward(self, x, y):
        """Return the mean absolute difference between the features of ``x`` and ``y``."""
        feats_x = self.slice1(x)
        feats_y = self.slice1(y)
        return nn.functional.l1_loss(feats_x, feats_y)


In [16]:
def save_model(epoch, model):
    """Write ``model``'s state dict to ``save_dir`` and report it.

    Args:
        epoch: Zero-based epoch index; the file name and the message use ``epoch + 1``.
        model: Module whose ``state_dict()`` is saved.
    """
    weights = model.state_dict()
    checkpoint_name = f'model_epoch_{epoch + 1}.pth'
    torch.save(weights, os.path.join(save_dir, checkpoint_name))
    print(f'Model saved at epoch {epoch + 1}')

def load_model(model_path):
    """Build a default ``Autoencoder`` and restore its weights from ``model_path``.

    Args:
        model_path: Path to a state-dict file such as the ones ``save_model`` writes.

    Returns:
        The model, moved to ``device``.
    """
    model = Autoencoder()
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model = model.to(device)
    print(f"Model loaded from {model_path}")
    return model

# Laplacian filter to extract high-frequency components.
# Stored with shape (1, 1, 3, 3) so it can be used directly as a conv2d weight.
laplacian_kernel = torch.tensor(
    [
        [0, -1, 0],
        [-1, 4, -1],
        [0, -1, 0],
    ],
    dtype=torch.float32,
)[None, None].to(device)

def high_pass_filter(img, kernel=None):
    """Apply the same 2-D filter to every channel of ``img`` independently.

    Args:
        img: Tensor of shape ``(batch, channels, height, width)``.
        kernel: Filter of shape ``(1, 1, kh, kw)``. Defaults to the module-level
            ``laplacian_kernel``.

    Returns:
        A tensor with the same shape, dtype and device as ``img``. The input is
        zero-padded by half the kernel size, so odd-sized kernels preserve the
        spatial size.
    """
    if kernel is None:
        kernel = laplacian_kernel
    channels = img.size(1)
    # Follow the input's device and dtype so the filter works regardless of
    # where the kernel was created.
    kernel = kernel.to(device=img.device, dtype=img.dtype)
    # One copy of the filter per channel + groups=channels is a depthwise
    # convolution: a single call instead of a Python loop over channels.
    weight = kernel.repeat(channels, 1, 1, 1)
    padding = (kernel.shape[-2] // 2, kernel.shape[-1] // 2)
    filtered = F.conv2d(img, weight, padding=padding, groups=channels)
    # Under autocast conv2d may return a lower-precision tensor; keep the input's dtype.
    return filtered.to(img.dtype)

def check_for_nans(tensor, name):
    """Report whether ``tensor`` contains any NaN value.

    Prints a message mentioning ``name`` when a NaN is found.

    Returns:
        ``True`` if at least one element is NaN, otherwise ``False``.
    """
    has_nan = bool(torch.isnan(tensor).any())
    if has_nan:
        print(f'NaNs found in {name}')
    return has_nan


In [40]:
import matplotlib.pyplot as plt

# Training loop with high-frequency emphasis
dataloader = data_loader
def train_autoencoder(model, start_epoch=0, num_epochs=100):
    """Train ``model`` on the module-level ``dataloader``.

    The loss is the sum of an L1 reconstruction term, the VGG perceptual term
    and an L1 term on the high-pass filtered images. Training uses mixed
    precision, gradient accumulation over ``accumulation_steps`` batches and
    gradient-norm clipping at ``max_grad_norm``. After every epoch a randomly
    chosen reconstruction is plotted and logged to TensorBoard, and a checkpoint
    is written every ``save_interval`` epochs.

    Reads the module-level names ``dataloader``, ``dataset``, ``device``,
    ``learning_rate``, ``weight_decay``, ``accumulation_steps``,
    ``max_grad_norm`` and ``save_interval``.

    Args:
        model: Module whose forward returns ``(reconstruction, latent)``.
        start_epoch: Zero-based index of the first epoch to run. This only
            offsets the epoch counter; no weights are loaded here.
        num_epochs: Epoch index at which training stops (exclusive).
    """
    criterion = nn.L1Loss()
    perceptual_loss = PerceptualLoss().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scaler = GradScaler()
    writer = SummaryWriter()
    num_batches = len(dataloader)

    try:
        for epoch in range(start_epoch, num_epochs):
            model.train()
            running_loss = 0.0
            valid_batches = 0    # batches that contributed to running_loss
            pending_batches = 0  # backward passes since the last optimizer step
            optimizer.zero_grad()

            for i, inputs in enumerate(dataloader):
                inputs = inputs.to(device)
                # Fold any leading batch/frame axes together: (..., C, H, W) -> (N, C, H, W).
                inputs = inputs.reshape(-1, *inputs.shape[-3:])

                with autocast():
                    reconstructed, _ = model(inputs)
                    has_nans = check_for_nans(reconstructed, 'reconstructed')
                    if not has_nans:
                        high_freq_inputs = high_pass_filter(inputs)
                        high_freq_reconstructed = high_pass_filter(reconstructed)
                        loss = criterion(reconstructed, inputs) + perceptual_loss(reconstructed, inputs)
                        high_freq_loss = criterion(high_freq_reconstructed, high_freq_inputs)
                        loss = loss + high_freq_loss
                        loss = loss / accumulation_steps

                # A batch with NaNs is skipped, but the step check below still
                # runs so already accumulated gradients are not left behind.
                if not has_nans:
                    scaler.scale(loss).backward()
                    pending_batches += 1
                    valid_batches += 1

                    batch_loss = loss.item() * accumulation_steps
                    running_loss += batch_loss

                    if i % 10 == 0:
                        print(f'Epoch {epoch + 1}, Step {i + 1}/{len(dataloader)}, Loss: {batch_loss}')
                        writer.add_scalar('Loss/train', batch_loss, epoch * len(dataloader) + i)

                # Step once enough batches are accumulated, and also on the last
                # batch so a partial window is applied instead of being wiped by
                # the zero_grad() at the start of the next epoch.
                is_last_batch = (i + 1) == num_batches
                if pending_batches and (pending_batches >= accumulation_steps or is_last_batch):
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    pending_batches = 0

            # Average over the batches that actually produced a loss.
            avg_loss = running_loss / max(valid_batches, 1)
            writer.add_scalar('Loss/epoch', avg_loss, epoch)
            print(f'Epoch {epoch + 1} finished with avg loss: {avg_loss}')

            model.eval()
            with torch.no_grad():
                random_idx = np.random.randint(0, len(dataset))
                # Keep only the first frame so the sample is (1, C, H, W) for any num_frames.
                random_image = dataset[random_idx][:1].to(device)
                reconstructed_image, _ = model(random_image)
                # Map [-1, 1] to [0, 1] for display. The reconstruction is a Tanh
                # output; for the original this assumes the data is normalised to
                # the same range. clamp() guards against values outside it.
                original_vis = ((random_image.squeeze(0) + 1) / 2).clamp(0, 1)
                reconstructed_vis = ((reconstructed_image.squeeze(0) + 1) / 2).clamp(0, 1)
                plt.imshow(reconstructed_vis.permute(1, 2, 0).cpu().numpy())
                plt.show()
                writer.add_image('Original Image', original_vis, epoch)
                writer.add_image('Reconstructed Image', reconstructed_vis, epoch)

            if (epoch + 1) % save_interval == 0:
                save_model(epoch, model)
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()


In [41]:
# Build a fresh autoencoder (default latent size) and place it on the selected device.
model = Autoencoder()
model = model.to(device)

In [43]:
# Start TensorBoard
# `%load_ext` loads the TensorBoard notebook extension and `%tensorboard`
# opens the dashboard on the `runs` log directory.
%load_ext tensorboard
%tensorboard --logdir=runs

# NOTE(review): training starts at epoch index 20 with the `model` built in the
# cell above; no checkpoint load is visible in this notebook, so confirm the
# weights were restored (e.g. with `load_model`) before resuming.
train_autoencoder(model, start_epoch=20, num_epochs=100)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 6564), started 3:29:57 ago. (Use '!kill 6564' to kill it.)

In [39]:
# Shell escape that terminates process 6564, the pid `%tensorboard` reported for
# the running TensorBoard instance. The pid is specific to one session, so this
# line has to be updated (or removed) before the notebook is re-run.
!kill 6564

kill: 6564: No such process
