In [None]:
import os
import urllib.request
from urllib.error import HTTPError
#plotting tools
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
matplotlib.rcParams['lines.linewidth'] = 2.0
sns.reset_orig()
sns.set()
from PIL import Image
from tqdm.notebook import tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torchvision 
from torchvision.datasets import CIFAR10, CocoDetection
from torchvision import transforms
from pycocotools.coco import COCO

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

# Root folder for datasets and folder where checkpoints / logs are written.
DATASET_PATH = "data"
CHECKPOINT_PATH = "saved_models/Autoencoder/"

# Fix the random seed so that splits and weight initialisation are repeatable.
pl.seed_everything(77)

# Prefer deterministic cuDNN kernels over auto-tuned (possibly non-deterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Fall back to the CPU when no CUDA device is present instead of assuming one.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Locations of the COCO images, the instance annotations and the background boxes.
PATH_TO_DATA = "data/COCO/train2017"
PATH_TO_ANN = "data/COCO/annotations/instances_train2017.json"
PATH_TO_BGBBOX = "data/COCO/annotations/coco_train_bg_bboxes.log"

print("Device: ",device)

In [None]:
class CocoClsDataset(data.Dataset):
    """Crop-level dataset built from COCO object boxes plus background boxes.

    Every usable COCO annotation and every background box listed in
    ``bg_bboxes_file`` is one sample. A sample is the image region inside the
    box, resized to 256x256 and passed through ``ToTensor`` + ``Normalize(0.5, 0.5)``,
    together with a one-hot target of length 2: index 0 is set for background
    crops, index 1 for a crop of any object category.
    """

    def __init__(self, img_dir, ann_file, bg_bboxes_file):
        """Index the annotations and load the background boxes.

        Args:
            img_dir: directory containing the image files.
            ann_file: path to the COCO instances annotation JSON.
            bg_bboxes_file: path to the text file with background boxes
                (format described in ``_load_bg_anns``).
        """
        # Honour the constructor arguments rather than module-level globals.
        self.ann_file = ann_file
        self.img_dir = img_dir
        self.coco = COCO(self.ann_file)
        self.bg_bboxes_file = bg_bboxes_file
        self.transform = transforms.Compose([transforms.Resize((256, 256)),
                                             transforms.ToTensor(),
                                             transforms.Normalize((0.5,), (0.5,))])

        categories = self.coco.dataset['categories']
        self.id2cat, self.id2label, self.label2id = self._build_category_maps(categories)

        # Keep only annotations with a positive area, a box of at least one
        # pixel in each direction, and that are not crowd regions.
        self.ann_ids = []
        for ann_id in self.coco.getAnnIds():
            ann = self.coco.loadAnns([ann_id])[0]
            _, _, w, h = ann['bbox']
            if ann['area'] <= 0 or int(w) < 1 or int(h) < 1 or ann['iscrowd']:
                continue
            self.ann_ids.append(ann_id)

        self.bg_anns = self._load_bg_anns()
        self._cal_num_dict()
        print('total_length of dataset:', len(self))

    @staticmethod
    def _build_category_maps(categories):
        """Build the lookup tables between COCO category ids, names and labels.

        Args:
            categories: list of dicts with at least the keys ``'id'`` and ``'name'``.

        Returns:
            Tuple ``(id2cat, id2label, label2id)``. Id 0 / label 0 is reserved
            for ``'background'``; real categories get labels 1..N in list order.
            ``label2id`` is the inverse of ``id2label``.
        """
        id2cat = {category['id']: category['name'] for category in categories}
        id2cat[0] = 'background'
        id2label = {category['id']: label + 1 for label, category in enumerate(categories)}
        id2label[0] = 0
        # items() yields (category id, label); swap them to get the inverse map.
        label2id = {label: cat_id for cat_id, label in id2label.items()}
        return id2cat, id2label, label2id

    def _cal_num_dict(self):
        """Count the samples per category name (plus 'background') into ``self.num_dict``."""
        self.num_dict = {}
        for ann_id in self.ann_ids:
            ann = self.coco.loadAnns([ann_id])[0]
            cat = self.id2cat[ann['category_id']]
            self.num_dict[cat] = self.num_dict.get(cat, 0) + 1
        self.num_dict['background'] = len(self.bg_anns)

    def _load_bg_anns(self):
        """Parse the background box file.

        The file is a sequence of records: a header line ``<file_name> <num>``
        followed by ``num`` lines of whitespace-separated numbers whose first
        four values are read as ``x1 y1 x2 y2``. Parsing stops at the first
        blank line or at end of file.

        Returns:
            List of dicts ``{'file_name': str, 'bbox': [x1, y1, w, h, ...]}``
            where ``w = x2 - x1 + 1`` and ``h = y2 - y1 + 1``.
        """
        assert os.path.exists(self.bg_bboxes_file), \
            'background bbox file not found: %s' % self.bg_bboxes_file
        bg_anns = []
        with open(self.bg_bboxes_file, 'r') as f:
            line = f.readline()
            while line:
                if line.strip() == '':
                    break
                file_name, num = line.strip().split()
                for _ in range(int(num)):
                    bbox = [float(value) for value in f.readline().strip().split()]
                    # Convert corner coordinates to width / height in place.
                    w = bbox[2] - bbox[0] + 1
                    h = bbox[3] - bbox[1] + 1
                    bbox[2], bbox[3] = w, h
                    bg_anns.append(dict(file_name=file_name, bbox=bbox))
                line = f.readline()
        return bg_anns

    def __len__(self):
        """Number of object crops plus number of background crops."""
        return len(self.ann_ids) + len(self.bg_anns)

    def __getitem__(self, idx):
        """Return ``(image_tensor, one_hot_target)`` for sample ``idx``.

        Indices below ``len(self.ann_ids)`` address COCO annotations; the
        remaining indices address background boxes.

        Raises:
            RuntimeError: if the crop cannot be transformed (for example
                because the box is empty); the message names the sample.
        """
        if idx < len(self.ann_ids):
            ann = self.coco.loadAnns([self.ann_ids[idx]])[0]
            label = self.id2label[ann['category_id']]
            img_meta = self.coco.loadImgs(ann['image_id'])[0]
            img_path = os.path.join(self.img_dir, img_meta['file_name'])
        else:
            ann = self.bg_anns[idx - len(self.ann_ids)]
            label = 0
            img_path = os.path.join(self.img_dir, ann['file_name'])

        x, y, w, h = ann['bbox']
        x, y, w, h = int(x), int(y), int(w), int(h)

        # convert() returns a loaded copy, so the file handle can be closed right away.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        # PIL's crop box is (left, upper, right, lower) with right/lower
        # exclusive, so a w x h box ends at x + w / y + h.
        img = img.crop((x, y, x + w, y + h))

        try:
            img = self.transform(img)
        except Exception as err:
            # Surface the failing sample instead of exiting the process.
            raise RuntimeError(
                'failed to transform sample %d (%s, bbox=%s, mode=%s)'
                % (idx, img_path, (x, y, w, h), img.mode)) from err

        # Collapse all object categories into a single "object" class.
        tmp_label = torch.zeros(2)
        tmp_label[1 if label != 0 else 0] = 1
        return img, tmp_label

In [None]:
# Build the crop dataset; indexing the annotations takes at least 10 sec to execute.
coco_dataset = CocoClsDataset(
    img_dir=PATH_TO_DATA,
    ann_file=PATH_TO_ANN,
    bg_bboxes_file=PATH_TO_BGBBOX,
)

In [None]:
# Look at a single sample: the cropped image and its one-hot target.
images, labels = coco_dataset[130000]
# The dataset normalises pixels with mean 0.5 / std 0.5, so undo that before
# plotting; imshow would otherwise clip every negative value to black.
plt.imshow((images * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
print(labels)

In [None]:
# Total number of samples: object crops plus background crops.
len(coco_dataset)

In [None]:
# Work on a random subset of the full dataset. The size of the discarded
# remainder is derived from the dataset itself so the split keeps working
# when the annotation or background files change.
SUBSET_SIZE = 100000
train_data, _ = torch.utils.data.random_split(
    coco_dataset, [SUBSET_SIZE, len(coco_dataset) - SUBSET_SIZE])
# 80k train / 15k validation / 5k test out of the 100k subset.
train_set, holdout_set = torch.utils.data.random_split(train_data, [80000, 20000])
val_set, test_set = torch.utils.data.random_split(holdout_set, [15000, 5000])

train_loader = DataLoader(train_set, batch_size=100, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = DataLoader(val_set, batch_size=100, shuffle=False, drop_last=False, num_workers=4)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, drop_last=False, num_workers=4)

def get_train_images(num):
    """Stack the image tensors of the first ``num`` samples of ``train_data`` into one batch."""
    first_images = []
    for sample_idx in range(num):
        image, _ = train_data[sample_idx]
        first_images.append(image)
    return torch.stack(first_images, dim=0)

In [None]:
# Size of the training split.
len(train_set)

In [None]:
# Size of the validation split.
len(val_set)

In [None]:
# Size of the test split.
len(test_set)

In [None]:
# Shape of the image tensor of the first validation sample.
print(val_set[0][0].shape)

In [None]:

class Encoder(nn.Module):
    """Convolutional encoder that compresses a 256x256 image into a latent vector.

    Six stride-2 convolutions halve the resolution from 256 down to 4; between
    them sit stride-1 convolutions that keep the resolution. The last three
    convolutions work on ``2 * base_channel_size`` channels. The 4x4 feature
    map is flattened and projected to ``latent_dim`` by a linear layer.

    Args:
        num_input_channels: channels of the input image.
        base_channel_size: channel count of the first convolutions.
        latent_dim: size of the output vector.
        act_fn: activation class instantiated after every convolution.
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        super().__init__()
        c_hid = base_channel_size
        # (in_channels, out_channels, stride) for each 3x3 convolution, in order.
        conv_plan = [
            (num_input_channels, c_hid, 2),  # 256x256 ==> 128x128
            (c_hid, c_hid, 1),
            (c_hid, c_hid, 2),               # 128x128 ==> 64x64
            (c_hid, c_hid, 1),
            (c_hid, c_hid, 2),               # 64x64 ==> 32x32
            (c_hid, c_hid, 1),
            (c_hid, c_hid, 2),               # 32x32 ==> 16x16
            (c_hid, c_hid, 1),
            (c_hid, 2*c_hid, 2),             # 16x16 ==> 8x8
            (2*c_hid, 2*c_hid, 1),
            (2*c_hid, 2*c_hid, 2),           # 8x8 ==> 4x4
        ]
        layers = []
        for c_in, c_out, stride in conv_plan:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride))
            layers.append(act_fn())
        # 2*c_hid channels on a 4x4 grid -> 2*16*c_hid features.
        layers.append(nn.Flatten())
        layers.append(nn.Linear(2*16*c_hid, latent_dim))
        self.net = nn.Sequential(*layers)

    def forward(self,x):
        """Encode a batch of images into latent vectors."""
        return self.net(x)

In [None]:

class Decoder(nn.Module):
    """Convolutional decoder that expands a latent vector into a 256x256 image.

    A linear layer maps the latent vector to ``2 * base_channel_size`` feature
    maps of size 4x4. Six stride-2 transposed convolutions then double the
    resolution up to 256, with resolution-preserving convolutions in between,
    and a final ``Tanh`` bounds the output to [-1, 1].

    Args:
        num_input_channels: channels of the reconstructed image.
        base_channel_size: channel count of the last convolutions.
        latent_dim: size of the input vector.
        act_fn: activation class instantiated between the layers.
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):

        super().__init__()
        c_hid = base_channel_size

        def upsample(c_in, c_out):
            # Transposed convolution that doubles height and width.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=3, output_padding=1, padding=1, stride=2)

        def refine(c_in, c_out):
            # Convolution that keeps height and width.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        self.linear = nn.Sequential(nn.Linear(latent_dim, 2*16*c_hid), act_fn())

        layers = [
            upsample(2*c_hid, 2*c_hid), act_fn(),   # 4x4 ==> 8x8
            refine(2*c_hid, 2*c_hid), act_fn(),
            upsample(2*c_hid, c_hid), act_fn(),     # 8x8 ==> 16x16
            refine(c_hid, c_hid), act_fn(),
            upsample(c_hid, c_hid), act_fn(),       # 16x16 ==> 32x32
            refine(c_hid, c_hid), act_fn(),
            upsample(c_hid, c_hid), act_fn(),       # 32x32 ==> 64x64
            refine(c_hid, c_hid), act_fn(),
            # NOTE(review): unlike every other stage there is no activation
            # after this upsampling layer. Left as is, because inserting one
            # would shift the Sequential indices and break saved checkpoints.
            upsample(c_hid, c_hid),                 # 64x64 ==> 128x128
            refine(c_hid, c_hid), act_fn(),
            upsample(c_hid, num_input_channels),    # 128x128 ==> 256x256
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a batch of latent vectors into images."""
        features = self.linear(x)
        # Fold the flat feature vector back into a (batch, channels, 4, 4) map.
        features = features.reshape(features.shape[0], -1, 4, 4)
        return self.net(features)

In [None]:
class Autoencoder(pl.LightningModule):
    """Lightning module that trains an encoder/decoder pair on image reconstruction.

    Args:
        base_channel_size: channel width passed to the encoder and decoder.
        latent_dim: size of the bottleneck vector.
        encoder_class: class used to build the encoder.
        decoder_class: class used to build the decoder.
        num_input_channels: channels of the input images.
        width: width of the example input used for graph logging.
        height: height of the example input used for graph logging.
    """

    def __init__(self,
                 base_channel_size : int,
                 latent_dim : int,
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels : int = 3,
                 width : int = 256,
                 height : int = 256
        ):
        super().__init__()
        # Store the constructor arguments so checkpoints can rebuild the model.
        self.save_hyperparameters()
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Image batches are laid out as (N, C, H, W), so height comes before width.
        self.example_input_array = torch.zeros(2, num_input_channels, height, width)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch, mode="mse"):
        """Return the reconstruction loss of one batch.

        The element-wise loss is summed over the channel and spatial
        dimensions of each image and then averaged over the batch.

        Args:
            batch: tuple ``(images, labels)``; the labels are ignored.
            mode: ``"mse"`` for squared error; any other value selects the
                ``F.kl_div`` branch.
        """
        x, _ = batch
        x_hat = self.forward(x)
        if mode == "mse":
            # (input, target) order; the squared error itself is symmetric.
            loss = F.mse_loss(x_hat, x, reduction="none")
        else:
            # NOTE(review): F.kl_div expects log-probabilities as its first
            # argument and probabilities as its second. Raw images and tanh
            # outputs are neither, so this branch is not a meaningful KL term.
            # No caller in this notebook selects it; behaviour kept unchanged.
            loss = F.kl_div(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        """Adam with a plateau scheduler driven by the logged ``val_loss``."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Multiply the learning rate by 0.2 after 20 epochs without
        # improvement, but never go below 5e-5.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode = 'min',
                                                         factor=0.2,
                                                         patience=20,
                                                         min_lr=5e-5)
        return {"optimizer":optimizer, "lr_scheduler":scheduler, "monitor":"val_loss"}

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss as ``train_loss``."""
        loss = self._get_reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss as ``val_loss``."""
        loss = self._get_reconstruction_loss(batch)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss as ``test_loss``."""
        loss = self._get_reconstruction_loss(batch)
        self.log('test_loss', loss)
        return loss

In [None]:
def compare_imgs(img1, img2, title_prefix=""):
    """Plot two images side by side, titled with their summed squared error.

    Args:
        img1: first image tensor.
        img2: second image tensor, same shape as ``img1``.
        title_prefix: text placed in front of the loss value in the title.
    """
    loss = F.mse_loss(img1, img2, reduction="sum")
    pair = torch.stack([img1, img2], dim=0)
    # NOTE(review): newer torchvision releases appear to call this keyword
    # `value_range` - confirm against the installed version before upgrading.
    grid = torchvision.utils.make_grid(pair, nrow=2, normalize=True, range=(-1,1))
    # make_grid returns channels first; imshow wants channels last.
    channels_last = grid.permute(1,2,0)
    plt.figure(figsize=(4,2))
    plt.title(f"{title_prefix} Loss : {loss.item():4.2f}")
    plt.imshow(channels_last)
    plt.axis('off')
    plt.show()

# Show how the summed squared error reacts to two kinds of image corruption.
for i in range(2):
    # Load example image
    img, _ = train_data[i]
    # Per-channel mean, kept broadcastable against the image.
    img_mean = img.mean(dim=[1,2], keepdims=True)

    # Shift image by one pixel along both spatial axes and overwrite the
    # wrapped-around first row and first column with the channel mean
    SHIFT = 1
    img_shifted = torch.roll(img, shifts=SHIFT, dims=1)
    img_shifted = torch.roll(img_shifted, shifts=SHIFT, dims=2)
    img_shifted[:,:1,:] = img_mean
    img_shifted[:,:,:1] = img_mean
    compare_imgs(img, img_shifted, "Shifted -")

    # Replace the top half of the image with the channel mean
    img_masked = img.clone()
    img_masked[:,:img_masked.shape[1]//2,:] = img_mean
    compare_imgs(img, img_masked, "Masked -")

In [None]:
class GenerateCallback(pl.Callback):
    """Log reconstructions of a fixed set of images to the experiment logger.

    Args:
        input_imgs: batch of images to reconstruct.
        every_n_epochs: log only on epochs divisible by this number.
    """

    def __init__(self, input_imgs, every_n_epochs=1):
        super().__init__()
        self.input_imgs = input_imgs
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, trainer, pl_module):
        """Reconstruct the stored images and add them as an image grid."""
        if trainer.current_epoch % self.every_n_epochs == 0:
            input_imgs = self.input_imgs.to(pl_module.device)
            # Remember the current mode and restore it afterwards, so the
            # module is not forced into train mode by this hook.
            was_training = pl_module.training
            pl_module.eval()
            try:
                with torch.no_grad():
                    reconst_imgs = pl_module(input_imgs)
            finally:
                pl_module.train(was_training)
            # Interleave originals and reconstructions: row k shows the pair for image k.
            imgs = torch.stack([input_imgs, reconst_imgs], dim = 1).flatten(0,1)
            grid = torchvision.utils.make_grid(imgs, nrow=2, normalize=True, range=(-1,1))
            trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)

In [None]:
def train_coco(latent_dim):
    """Train an autoencoder with the given bottleneck size, or load a saved one.

    If ``coco80_<latent_dim>.ckpt`` exists in ``CHECKPOINT_PATH`` it is loaded
    and training is skipped; otherwise a new model is fitted on
    ``train_loader`` / ``val_loader``. The model is then evaluated on the
    validation and test loaders.

    Args:
        latent_dim: size of the latent vector.

    Returns:
        Tuple ``(model, result)`` where ``result`` is a dict with the keys
        ``"test"`` and ``"val"`` holding the outputs of ``trainer.test``.
    """
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "coco80_%i" %latent_dim),
                         # Follow the notebook-wide `device` instead of assuming a GPU.
                         gpus=1 if str(device).startswith("cuda") else 0,
                         max_epochs=300,
                         callbacks=[ModelCheckpoint(save_weights_only=True),
                                    GenerateCallback(get_train_images(8), every_n_epochs=10),
                                    LearningRateMonitor("epoch")])
    # Log the computation graph; skip the default hp_metric entry.
    trainer.logger._log_graph = True
    trainer.logger._default_hp_metric = None

    pretrained_filename = os.path.join(CHECKPOINT_PATH, "coco80_%i.ckpt" % latent_dim)
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading ...")
        model = Autoencoder.load_from_checkpoint(pretrained_filename)
    else:
        model = Autoencoder(base_channel_size=64, latent_dim=latent_dim)
        trainer.fit(model, train_loader, val_loader)
    # Evaluate on the validation and on the test split.
    val_result = trainer.test(model, test_dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, test_dataloaders=test_loader, verbose=False)
    result = {"test":test_result, "val":val_result}
    return model, result

In [None]:
# Train (or load) one autoencoder per latent size and keep the model together
# with its validation / test results, keyed by the latent size.
model_dict = {}
for latent_dim in [384,512]:
    model_ld, result_ld = train_coco(latent_dim)
    model_dict[latent_dim] = {"model": model_ld, "result": result_ld}