In [None]:
%matplotlib notebook
import notebook

import torchsummary
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
from types import SimpleNamespace
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

import utils
from utils import load_train, AE, ResBlock2D, View, Flatten, as_shape
from SSIM import SSIM # structural similarity loss...

import jnu as J

# Experiment settings, exposed as attributes (config.device, config.batch_size, ...).
#   device        - torch device string the model and data are moved to
#   state_shape   - first argument passed to AE(...)
#   latent_shape  - second argument passed to AE(...)
#   batch_size    - batch size of the training DataLoader
#   learning_rate - learning rate handed to the Adam optimizer
#   epochs        - number of passes of the training loop
config = SimpleNamespace(
    device="cuda:0",
    state_shape=(3, 84, 84),
    latent_shape=(256,),
    batch_size=256,
    learning_rate=0.0005,
    epochs=10,
)

In [None]:
# Build the auto-encoder from the configured shapes, then move it to the
# configured device.
model = AE(config.state_shape, config.latent_shape)
model = model.to(config.device)
# Optional warm start from an earlier checkpoint:
# model.load_state_dict(torch.load("./AE.pt"))

# Adam over every model parameter, using the configured learning rate.
optim = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate)

# Training objective: the SSIM criterion from the local SSIM module.
# Alternative objective kept for reference:
# criterion = nn.BCEWithLogitsLoss()
criterion = SSIM()
# print(torchsummary.summary(model, config.state_shape))

In [None]:
# Wrap each training episode's observations (ep[0]) in its own TensorDataset,
# placed on the configured device, then chain all episodes into one dataset.
episode_datasets = []
for ep in utils.load_train(keys=["observation"]):
    observations = torch.from_numpy(ep[0]).to(config.device)
    episode_datasets.append(TensorDataset(observations))
dataset = ConcatDataset(episode_datasets)

# Shuffled mini-batches; the final incomplete batch of each pass is dropped.
loader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Live progress widgets: an HTML label for the loss and an image preview that
# is seeded with the first observation of one batch drawn from the loader.
loss_msg = "Loss: {0}"
msg = J.HTML(loss_msg.format(0))
display(msg)
first_batch = next(iter(loader))
img = J.image(first_batch[0][0])

# Overrides the epoch count that was set when `config` was created.
config.epochs = 1000
for epoch in tqdm(range(config.epochs)):
    # Each loader item is a 1-tuple holding the observation batch.
    for (batch_obs,) in loader:
        # batch_obs = batch_obs.to(config.device)
        optim.zero_grad()
        y = model(batch_obs)
        loss = criterion(y, batch_obs)
        loss.backward()
        optim.step()

    # Once per epoch: report the loss of the last batch and preview the
    # first reconstruction from that same batch.
    msg.value = loss_msg.format(loss.item())
    # img.update(torch.sigmoid(y[0]))
    img.update(y[0])


In [None]:

# Rebind `loader` to a deterministic (unshuffled) loader with 1024-sample
# batches and preview the first batch next to its reconstruction.
loader = DataLoader(dataset, batch_size=1024, shuffle=False, drop_last=True)
with torch.no_grad():
    batch_obs = next(iter(loader))[0]
    # Limit the model output to [0, 1] before it is displayed.
    reconstruction = model(batch_obs).clamp(0, 1)
    # Join input and reconstruction along dim 3 so each pair forms one image.
    imgs = torch.cat((batch_obs, reconstruction), dim=3).cpu().numpy()
    J.images(imgs, scale=3)
    
    
        

In [None]:
# Persist only the model's state dict (not the whole module object) to disk.
torch.save(obj=model.state_dict(), f="./AE-SSIM-256.pt")

In [None]:
# One (observations, bug mask) tensor pair per test episode. Observations are
# moved to the configured device; the mask stays where torch.from_numpy put it.
test_data = []
for ep in utils.load_test(keys=["observation", "bugmask"]):
    obs_tensor = torch.from_numpy(ep[0]).to(config.device)
    mask_tensor = torch.from_numpy(ep[1])
    test_data.append((obs_tensor, mask_tensor))


In [None]:
def _minmax_normalize(values):
    """Linearly rescale a 1-D array to the range [0, 1].

    Returns an all-zero array when every entry is equal and the input itself
    when it is empty. In both cases min == max, and np.interp requires
    increasing x-coordinates, so its result for a zero-width interval is not
    meaningful (e.g. an episode whose mask contains no bug pixels).
    """
    values = np.asarray(values, dtype=np.float64)
    if values.size == 0:
        return values
    lo, hi = values.min(), values.max()
    if hi == lo:
        return np.zeros_like(values)
    return np.interp(values, (lo, hi), (0.0, 1.0))


# NOTE(review): this model is built without config.latent_shape and restored
# from "./AE.pt", while the training cells build AE(state_shape, latent_shape)
# and save "./AE-SSIM-256.pt" - confirm this checkpoint/architecture pairing.
model = AE(config.state_shape)
# The evaluation below runs on the CPU, so map the checkpoint tensors there
# directly instead of restoring them to the device they were saved from.
model.load_state_dict(torch.load("./AE.pt", map_location="cpu"))
# reduction='none' keeps one loss value per element so it can be summed per sample.
criterion = nn.BCEWithLogitsLoss(reduction='none')
with torch.no_grad():
    n = 1024  # only the first n samples of each episode are evaluated
    model = model.cpu()
    for obs, mask in test_data:
        obs, mask = obs[:n].cpu(), mask[:n].cpu()
        pred = model(obs)

        # Per-sample score: element-wise BCE between reconstruction logits and
        # the input, summed over every non-batch dimension.
        score = criterion(pred, obs)
        score = score.reshape(score.shape[0], -1).sum(-1).numpy()
        score = _minmax_normalize(score)

        # Per-sample label: sum of the mask over every non-batch dimension.
        label = mask.reshape(mask.shape[0], -1).sum(-1).numpy()
        label = _minmax_normalize(label)

        # Reconstruction, input and mask joined along dim 3 for display.
        J.images(torch.cat([torch.sigmoid(pred), obs, mask], dim=3))

        # Normalized score and label plotted against the sample index.
        fig, ax = plt.subplots(figsize=(10, 5))
        sample_idx = np.arange(score.shape[0])
        ax.plot(sample_idx, score, label="score")
        ax.plot(sample_idx, label, label="label")
        ax.legend()

In [None]:
# NOTE(review): this cell originally opened `for data in test_data:` without a
# body, which is a syntax error, and the statements below never used `data`.
# The unfinished loop is removed so the cell can run; restore it if the
# intent was to preview every test episode.
with torch.no_grad():
    batch_obs = next(iter(loader))[0]
    # The evaluation cell moves `model` to the CPU while the dataset tensors
    # were created on config.device, so run on whichever device the model is on.
    batch_obs = batch_obs.to(next(model.parameters()).device)
    # Input and sigmoid-activated reconstruction joined along dim 3 for display.
    imgs = torch.cat([batch_obs, torch.sigmoid(model(batch_obs))], dim=3).cpu().numpy()
    J.images(imgs, scale=3)