In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from ae_model import AutoEncoder
import wandb
from tqdm import tqdm

In [2]:
# Authenticate with Weights & Biases. On this run it prompted for an API key
# interactively and appended it to the netrc file (see the output below).
# Never hardcode the key in the notebook itself.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
# Run configuration: every knob a reader might want to tune lives in this cell.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCH = 100          # number of passes over the training set
align = False        # when True, add the embedding-alignment term to the loss
align_weight = 0.1   # weight of the alignment term in the total loss
LR = 1e-3            # Adam learning rate
SEED = 0             # seed for torch's RNG (weight init, loader shuffling, random augmentation)

# Seed torch's generators so a Restart & Run All reproduces the same run.
# Trailing `;` suppresses the Generator repr in the cell output.
torch.manual_seed(SEED);

In [4]:
# Start a W&B run. The config mirrors the hyper-parameters from the config cell
# (rather than repeating literals) so the dashboard records what this run actually used.
wandb.init(
    project="rep_learning",
    config={
        "learning_rate": LR,
        "epochs": EPOCH,
        "align": align,
        "align_weight": align_weight,
    },
)

[34m[1mwandb[0m: Currently logged in as: [33madi30502[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Deterministic preprocessing done by the datasets: convert to a float tensor in
# [0, 1], then map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
base_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Random augmentation used by the training cell to build a second "view" of each
# image. It is applied to tensors that are already normalised, so there is
# deliberately no Normalize step here.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),      # flip horizontally with probability 0.5
    transforms.RandomCrop(32, padding=4),   # pad by 4 px (fill 0), then take a random 32x32 crop
])

BATCH_SIZE = 64  # shared by the train and test loaders

# CIFAR-10 train/test splits, downloaded to ./data on first run.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=base_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=base_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

# NOTE(review): the meaning of the positional `True` is defined in ae_model.AutoEncoder;
# confirm it and pass it by keyword for readability.
model = AutoEncoder([3, 128, 256, 512], True)
model.to(DEVICE)


def _per_sample_squared_distance(a, b):
    """Summed squared difference between two embedding batches, divided by the batch size a.shape[0]."""
    return torch.sum((a - b) ** 2) / a.shape[0]


def _no_alignment(a, b):
    """Alignment disabled: contribute a constant 0 so the term drops out of the loss."""
    return 0


# Chosen once, when this cell runs, from the `align` flag in the config cell.
align_criterion = _per_sample_squared_distance if align else _no_alignment

optimiser = torch.optim.Adam(model.parameters(), lr=LR)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:08<00:00, 21.0MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
EVAL_EVERY = 5  # evaluate on the test set every N epochs (and at epoch 0)


def _augment(batch):
    """Apply the random `transform` independently to each image in `batch`.

    torchvision (v1) transforms draw one set of random parameters per call, so
    calling `transform` on a whole 4-D batch would flip/crop every image
    identically. Transforming image by image gives per-sample augmentation.
    """
    return torch.stack([transform(img) for img in batch])


def _compute_losses(x):
    """Return (recon_loss, align_loss) for a batch `x` already on DEVICE.

    recon_loss: half the summed squared reconstruction error per sample, for
    both the original batch and its augmented copy.
    align_loss: `align_criterion` applied to the two embeddings (a tensor, or
    the plain 0 when alignment is disabled).
    """
    with torch.no_grad():
        x_aug = _augment(x)
    emb, x_pred = model(x)
    emb_aug, x_pred_aug = model(x_aug)
    batch_size = x.shape[0]
    loss_recon = (torch.sum((x_pred - x) ** 2)
                  + torch.sum((x_pred_aug - x_aug) ** 2)) / (2 * batch_size)
    loss_align = align_criterion(emb, emb_aug)
    return loss_recon, loss_align


for i in tqdm(range(EPOCH), desc="epochs"):
    model.train()
    for x, _ in trainloader:
        x = x.to(DEVICE)
        loss_recon, loss_align = _compute_losses(x)
        loss = loss_recon + align_weight * loss_align

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        # float() works for a 0-d tensor and for the int 0 returned when alignment
        # is off; logging plain floats avoids handing autograd tensors to wandb.
        wandb.log({"train/recon_loss": float(loss_recon),
                   "train/align_loss": float(loss_align)})

    if i % EVAL_EVERY == 0:
        # eval() so layers such as batchnorm/dropout (if the model has any) use
        # inference behaviour and do not update their state on test data.
        model.eval()
        recon_total, align_total, n_seen = 0.0, 0.0, 0
        with torch.no_grad():
            # Average over the full test set (weighted by batch size) instead of
            # reporting a single noisy 64-image batch.
            for x, _ in testloader:
                x = x.to(DEVICE)
                loss_recon, loss_align = _compute_losses(x)
                recon_total += float(loss_recon) * x.shape[0]
                align_total += float(loss_align) * x.shape[0]
                n_seen += x.shape[0]
        if n_seen:
            wandb.log({"test/recon_loss": recon_total / n_seen,
                       "test/align_loss": align_total / n_seen,
                       "epoch": i})

  0%|          | 0/100 [00:00<?, ?it/s]