In [1]:
import sys; sys.path.append('..')
import torch
from torchsummary import summary

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import ToTensor

# Train and test splits of CIFAR-100 share one root directory;
# transform=ToTensor() is applied to every image that is fetched.
ds_train, ds_test = (
    CIFAR100('../data/CIFAR100', train=is_train_split, transform=ToTensor())
    for is_train_split in (True, False)
)

In [3]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset

# ds_train = Subset(ds_train, range(100))
# ds_test = Subset(ds_test, range(100))

NUM_CLASSES = 100

# Only the training loader shuffles; the test loader keeps dataset order.
dl_train, dl_test = (
    DataLoader(dataset, batch_size=256, shuffle=reshuffle)
    for dataset, reshuffle in ((ds_train, True), (ds_test, False))
)

In [4]:
import torch
import torch.nn as nn
from torch import Tensor
from firelab.config import Config

from src.models.layers import ConvBNReLU, ConvTransposeBNReLU, Reshape


class AutoEncoder(nn.Module):
    """Encoder/decoder pair whose forward pass reconstructs its input."""

    def __init__(self, config: Config):
        super().__init__()

        # The same config object is handed to both halves.
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def forward(self, x: Tensor) -> Tensor:
        """Encode `x`, then decode the resulting code."""
        code = self.encoder(x)
        return self.decoder(code)


class Encoder(nn.Module):
    """Four stride-2 convolutions, each followed by ReLU.

    For a 32x32 RGB input the spatial size shrinks 16 -> 8 -> 4 -> 1, so the
    output is [batch, 512, 1, 1]. `config` is accepted but not read here.
    """

    def __init__(self, config: Config):
        super().__init__()

        # (in_channels, out_channels, kernel_size, padding) for each stage
        conv_specs = (
            (3, 16, 4, 1),
            (16, 64, 4, 1),
            (64, 256, 4, 1),
            (256, 512, 3, 0),
        )

        layers = []
        for in_ch, out_ch, kernel, pad in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel, stride=2, padding=pad))
            layers.append(nn.ReLU())

        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run `x` through the convolutional stack."""
        encoded = self.model(x)
        return encoded
    

class Decoder(nn.Module):
    """Four stride-2 transposed convolutions that upsample the code to an image.

    ReLU follows the first three layers and Sigmoid the last. The shape
    comments below assume a [batch, 512, 1, 1] input. `config` is accepted
    but not read here.
    """

    def __init__(self, config: Config):
        super().__init__()

        upsamplers = [
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2),             # [batch, 256, 4, 4]
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # [batch, 128, 8, 8]
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # [batch, 64, 16, 16]
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),     # [batch, 3, 32, 32]
        ]
        activations = [nn.ReLU(), nn.ReLU(), nn.ReLU(), nn.Sigmoid()]

        # Interleave as conv, activation, conv, activation, ...
        layers = []
        for upsample, activation in zip(upsamplers, activations):
            layers.extend((upsample, activation))

        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run `x` through the transposed-convolution stack."""
        decoded = self.model(x)
        return decoded


class ShapePrinter(nn.Module):
    """Pass-through debugging layer that prints the shape of its input."""

    def __init__(self, title: str='unknown'):
        super().__init__()
        self.title = title

    def forward(self, x):
        """Print the title, `x.shape` and the shape of `x` flattened from dim 1; return `x` unchanged."""
        full_shape = x.shape
        flat_shape = x.flatten(1).shape
        print(self.title, full_shape, flat_shape)
        return x

In [5]:
# ae = AutoEncoder(None).to('cuda')
# summary(ae, (3, 32, 32))

In [None]:
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np

max_num_epochs = 1000

# Autoencoder built without a config, moved to the target device,
# and optimized by Adam with its default settings.
ae = AutoEncoder(config=None).to(DEVICE)
optim = torch.optim.Adam(params=ae.parameters())


def validate_ae(model, dataloader):
    """Return the mean per-element squared reconstruction error over `dataloader`.

    Args:
        model: an `nn.Module` mapping a batch to a reconstruction of the same shape.
        dataloader: iterable of `(inputs, _)` batches; the second item is ignored.

    Returns:
        `np.float64` mean of `(model(x) - x) ** 2` over every element of every
        batch, or NaN when the dataloader yields no elements.

    Side effects:
        Leaves `model` in eval mode.
    """
    model.eval()

    # Run on whatever device the model lives on; a parameter-free model
    # gives no device to infer, so inputs are left where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Keep a running sum and count rather than storing every per-element loss
    # as a Python float, which needs gigabytes for a full image dataset.
    squared_error_sum = 0.0
    num_elements = 0

    with torch.no_grad():
        for x, _ in dataloader:
            if device is not None:
                x = x.to(device)
            losses = F.mse_loss(model(x), x, reduction='none')
            # Sum in float64 so the total does not lose precision across batches.
            squared_error_sum += losses.double().sum().item()
            num_elements += losses.numel()

    if num_elements == 0:
        return np.float64('nan')
    return np.float64(squared_error_sum / num_elements)


for epoch in tqdm(range(1, max_num_epochs + 1)):
    # validate_ae() switches the model to eval mode, so re-enable training
    # mode once per epoch instead of on every batch.
    ae.train()

    for x, _ in dl_train:
        x = x.to(DEVICE)

        # Reconstruction objective (alternative: F.binary_cross_entropy(ae(x), x))
        loss = F.mse_loss(ae(x), x)
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Report reconstruction error on both splits every 100 epochs.
    if epoch % 100 == 0:
        print(f'[train] Epoch {epoch}. {validate_ae(ae, dl_train)}')
        print(f'[val] Epoch {epoch}: {validate_ae(ae, dl_test)}')

  2%|▏         | 21/1000 [03:38<2:48:31, 10.33s/it]

In [None]:
import random
import matplotlib.pyplot as plt

ae.eval()

with torch.no_grad():
    # Pick 10 distinct test images and stack them into one [10, C, H, W] batch.
    # torch.stack avoids converting every image to nested Python lists and back.
    sample_ids = random.sample(range(len(ds_test)), 10)
    imgs = torch.stack([ds_test[i][0] for i in sample_ids])
    recs = ae(imgs.to(DEVICE)).cpu()

    # Channels-last numpy arrays, as expected by plt.imshow.
    imgs = imgs.permute(0, 2, 3, 1).numpy()
    recs = recs.permute(0, 2, 3, 1).numpy()

In [None]:
# Two stacked panels: originals on top, reconstructions below.
panels = (
    (211, '[val] Ground Truth', imgs),
    (212, '[val] Reconstructions', recs),
)

plt.figure(figsize=(20, 5))
for position, heading, batch in panels:
    plt.subplot(position)
    plt.title(heading)
    # Lay the images side by side: (N, 32, 32, 3) -> (32, N * 32, 3)
    tiled = np.stack(batch, axis=1).reshape(32, -1, 3)
    plt.imshow(tiled, interpolation='nearest')

In [None]:
def extract_features(ae, dataloader):
    """Encode every batch with `ae.encoder` and collect flattened features and labels.

    Args:
        ae: module exposing an `encoder` submodule; it is switched to eval mode.
        dataloader: iterable of batches where `batch[0]` holds the inputs and
            `batch[1]` the labels.

    Returns:
        `(feats, y)`: `feats` is a list with one flat list of floats per
        sample; `y` is the list of labels in the same order, as plain Python
        values when the labels arrive as a tensor.
    """
    ae.eval()

    # Run on whatever device the model lives on; a parameter-free model
    # gives no device to infer, so inputs are left where they are.
    first_param = next(ae.parameters(), None)
    device = first_param.device if first_param is not None else None

    feats = []
    y = []

    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = batch[0], batch[1]
            if device is not None:
                inputs = inputs.to(device)

            codes = ae.encoder(inputs).cpu().flatten(1)
            feats.extend(codes.tolist())

            # Extending a list with a tensor would store one 0-d tensor per label;
            # convert to plain Python values, as is done for the features.
            if isinstance(labels, torch.Tensor):
                labels = labels.tolist()
            y.extend(labels)

    return feats, y

# Encode the training split first, then the test split.
(feats_train, y_train), (feats_test, y_test) = (
    extract_features(ae, loader) for loader in (dl_train, dl_test)
)

In [None]:
def _transpose_batch(samples):
    """Turn a list of (features, label) pairs into [features-tuple, labels-tuple]."""
    return list(zip(*samples))


# Pair every feature vector with its label so the loaders can batch them together.
train_pairs = list(zip(feats_train, y_train))
test_pairs = list(zip(feats_test, y_test))

feats_dl_train = DataLoader(train_pairs, batch_size=256, shuffle=True, collate_fn=_transpose_batch)
feats_dl_test = DataLoader(test_pairs, batch_size=256, shuffle=False, collate_fn=_transpose_batch)

In [None]:
# (features, label) pairs of the training split as a plain list;
# `ds` is not referenced by the other visible cells.
ds = list(zip(feats_train, y_train))

In [None]:
def validate_clf(clf_model, dataloader):
    """Compute mean cross-entropy loss and accuracy of `clf_model` over `dataloader`.

    Args:
        clf_model: an `nn.Module` mapping a float feature batch to class logits.
        dataloader: iterable of `(x, y)` batches, where `x` is a sequence of
            feature vectors and `y` a sequence of integer class labels; both
            are converted with `torch.tensor`.

    Returns:
        `(mean_loss, mean_accuracy)`, each averaged per sample with `np.mean`.

    The model is evaluated in eval mode, so Dropout/BatchNorm layers behave
    deterministically, and its previous train/eval mode is restored afterwards.
    """
    losses = []
    accs = []

    # Run on whatever device the model lives on; a parameter-free model
    # gives no device to infer, so tensors stay on the default device.
    first_param = next(clf_model.parameters(), None)
    device = first_param.device if first_param is not None else None

    was_training = clf_model.training
    clf_model.eval()

    try:
        with torch.no_grad():
            for x, y in dataloader:
                x = torch.tensor(x, device=device)
                y = torch.tensor(y, device=device)

                logits = clf_model(x)
                loss = F.cross_entropy(logits, y, reduction='none').cpu().tolist()
                acc = (logits.argmax(dim=1) == y).float().cpu().tolist()

                losses.extend(loss)
                accs.extend(acc)
    finally:
        # Hand the model back in the mode the caller left it in.
        clf_model.train(was_training)

    return np.mean(losses), np.mean(accs)

In [None]:
# Size the linear probe from the extracted features rather than a hard-coded
# width: for 32x32 inputs the encoder output flattens to 512 values, so a
# fixed 2048 input width fails with a shape mismatch on the first batch.
feat_dim = len(feats_train[0])

classifier = nn.Sequential(
    #nn.Dropout(0.1),
    #nn.Linear(feat_dim, 256),
    #nn.BatchNorm1d(256),
    #nn.ReLU(),
    #nn.Dropout(0.1),
    #nn.Linear(256, NUM_CLASSES)
    nn.Linear(feat_dim, NUM_CLASSES)
).to(DEVICE)

max_num_epochs = 100
optim = torch.optim.Adam(classifier.parameters(), lr=1e-2)

for epoch in range(max_num_epochs):
    # Explicitly train in training mode, whatever mode validation leaves behind.
    classifier.train()

    for batch in feats_dl_train:
        # Each batch is [features-tuple, labels-tuple] (see the loaders' collate_fn).
        x = torch.tensor(batch[0]).to(DEVICE)
        y = torch.tensor(batch[1]).to(DEVICE)

        logits = classifier(x)
        loss = F.cross_entropy(logits, y)

        optim.zero_grad()
        loss.backward()
        optim.step()

    # Report accuracy on both feature sets every 5 epochs.
    if epoch % 5 == 0:
        print('train acc:', validate_clf(classifier, feats_dl_train)[1])
        print('val acc:', validate_clf(classifier, feats_dl_test)[1])