In [None]:
import torch
import numpy as np
import torch.nn as nn
from torch.nn.modules.loss import MSELoss
from torch.utils.data import dataloader
import utils
from arguments import parse_args
import torch.optim as optim
import torch.utils.data as data

import torchvision
import torchvision.transforms as T
import PIL
import numpy as np
import matplotlib.pyplot as plt
import tqdm

import wandb
import cv2
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset
import os
from sampling import DelaunayTriangulationBlur
from PIL import Image
import random
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable
import numba as nb
from torchvision.transforms import ToTensor, ToPILImage

from IPython.display import clear_output as clear

%matplotlib inline
%load_ext autoreload
%autoreload 2

class TrainDataset_PictureOnly(data.Dataset):
    """Image-only training dataset read from ``<args.data_root>/train``.

    Every directory entry except the macOS ``.DS_Store`` metadata file is
    treated as an image. Items are tensors produced by ``ToTensor`` after the
    image has been converted to RGB and resized to 256x256; no label is
    returned.
    """

    def __init__(self, args):
        """Index the training directory.

        Args:
            args: Parsed arguments object; only ``args.data_root`` is read here.
        """
        self.args = args
        self.root_path = os.path.join(args.data_root, 'train')
        # os.listdir returns entries in arbitrary order; sorting makes an index
        # refer to the same file on every run. Filtering replaces the previous
        # remove()/bare-except, which silently swallowed every kind of error.
        self.img_list = sorted(
            name for name in os.listdir(self.root_path) if name != '.DS_Store'
        )

    def __getitem__(self, idx):
        """Return image ``idx`` as a tensor of shape (3, 256, 256)."""
        img_path = os.path.join(self.root_path, self.img_list[idx])
        # The context manager closes the underlying file once the pixels have
        # been decoded; convert('RGB') keeps the channel count fixed at 3 even
        # for grayscale or RGBA files.
        with Image.open(img_path) as img_raw:
            img = img_raw.convert('RGB').resize((256, 256))
        return ToTensor()(img)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.img_list)

In [None]:
# Hyper-parameters for the picture-only MAE pre-training run, passed to the
# project's argument parser as if they had been given on the command line.
args = parse_args([
    "--alg", "MAE",
    "--description", "mae_pretrain_piconly",
    "--lr", "1e-3",
    "--epoch", "100",
    '--data_root', 'data/celeba',
    '--batch_size', '1',
])

In [None]:
# Single source of truth for the random seed (previously repeated as a literal).
SEED = 31415926

# Hyper-parameters are handed to wandb through `config=` so they are stored on
# the run. Assigning a plain dict to `wandb.config` after init only replaces
# that attribute instead of updating the run's config. The values come from
# `args` so the logged config cannot drift from what training actually uses.
wandb.init(
    project="MAE-toy-modified-celebA-1000-12131124",
    entity="purewhite2019",
    config={
        "learning_rate": args.lr,
        "epochs": args.epoch,
        "batch_size": args.batch_size,
        "seed": SEED,
    },
)

# Training Preparation
utils.set_seed_everywhere(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

train_dataset = TrainDataset_PictureOnly(args)
train_loader = data.DataLoader(dataset=train_dataset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers,
                                # Pinned host memory only helps host-to-GPU copies.
                                pin_memory=(device.type == 'cuda'),
                                drop_last=True)

# Kept for reference; the training loop uses the loss returned by the model.
loss_function = nn.MSELoss()

# loss_function = lambda img_gt, img_pred : -utils.psnr(img_gt, img_pred) - utils.ssim(img_gt, img_pred)

from models.vit import ViT
from models.mae import MAE

# Must match the 256x256 resize done in TrainDataset_PictureOnly.__getitem__.
img_size, patch_size = (256, 256), (16, 16)

# encoder = ViT(img_size, patch_size, depth=12, dim=768, mlp_dim=3072, num_heads=12) # ViT-B/16
# encoder = ViT(img_size, patch_size, depth=24, dim=1024, mlp_dim=4096, num_heads=16) # ViT-L/16 (Default in MAE paper)
# encoder = ViT(img_size, patch_size, depth=32, dim=1280, mlp_dim=5120, num_heads=16) # ViT-H/16
encoder = ViT(img_size, patch_size, depth=6, dim=512, mlp_dim=1024, num_heads=8) # Simple

# model = MAE(encoder, decoder_depth=8, decoder_dim=512, mask_ratio=0.75) # (Default in MAE paper)
model = MAE(encoder, decoder_depth=6, decoder_dim=512, mask_ratio=0.75)
model.to(device)

optimizer = optim.RAdam(params=model.parameters(), lr=args.lr)
# The training loop calls scheduler.step() once per batch, so the learning
# rate is halved every 1000 optimizer steps (not every 1000 epochs).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)

# Epoch counter lives outside the training cell so that re-running that cell
# continues the numbering instead of restarting from 0.
e = 0

In [None]:
# Sanity check: number of batches train_loader yields per epoch.
print(len(train_loader))

In [None]:
# Start training
lr_list = []  # learning rate in effect after every optimizer/scheduler step
model.train()
wandb.watch(model)

# Number of batches per epoch; used as the denominator in the progress text.
num_batches = len(train_loader)

# for _ in range(800): # Default setting in MAE paper
for _ in range(args.epoch):
    # Pick one random training image for a qualitative preview of this epoch.
    preview_idx = random.randint(0, len(train_dataset) - 1)
    preview_img = train_dataset[preview_idx].unsqueeze(0).to(device)
    # Checkpoint the weights as they are at the START of epoch `e`.
    utils.save_model(model, e, args)
    clear()
    # The preview is inference only, so no autograd graph is needed.
    with torch.no_grad():
        recons_img, patches_to_img = model.predict(preview_img)
    # (C, H, W) -> (H, W, C) numpy arrays for plotting.
    recons_img = recons_img[0].permute(1, 2, 0).cpu().numpy()
    patches_to_img = patches_to_img[0].permute(1, 2, 0).cpu().numpy()
    img_gt = preview_img[0].permute(1, 2, 0).cpu().numpy()
    # NOTE(review): keyword names come from utils.show_gt_and_pred; the mapping
    # of masked input / reconstruction / ground truth onto them is kept as-is.
    utils.show_gt_and_pred(img_hr=patches_to_img, img_lr=recons_img, pred_hr=img_gt, figsize=(30, 30))
    # Re-assert training mode in case predict() switched the model to eval
    # (its implementation is not visible here).
    model.train()

    loop = tqdm.tqdm(train_loader)
    for batch_idx, img in enumerate(loop):

        img = img.to(device)
        # The MAE forward pass returns the reconstruction loss directly.
        loss = model(img)
        # Log and display a plain float rather than a tensor that still
        # references the autograd graph.
        loss_value = loss.item()

        wandb.log({"loss": loss_value})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Read the lr straight from the optimizer instead of rebuilding the
        # whole state_dict on every step.
        lr_list.append(optimizer.param_groups[0]['lr'])

        loop.set_description(f"epoch: {e} | iter: {batch_idx}/{num_batches} | loss: {loss_value}")
    e += 1

# Checkpoints are written at the start of each epoch, so without this final
# save the weights produced by the last epoch would never be stored.
utils.save_model(model, e, args)