In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
from model import UNET

## Data

In [None]:
# Both MNIST splits are read from the same directory. download= is not passed,
# so the files are expected to be present there already.
mnist_root = '../gans/'
dataset_train = datasets.MNIST(root=mnist_root, train=True)
dataset_test = datasets.MNIST(root=mnist_root, train=False)

In [None]:
# Raw pixel tensors of both splits, cast to 32-bit floats
data_train = dataset_train.data.to(torch.float32)
data_test = dataset_test.data.to(torch.float32)
# Largest training pixel value; both splits are divided by it further down
scale = torch.max(data_train)

In [None]:
# pseudo "poisson" directly in small range (more noise)
# Gaussian noise with mean noise_lambda and standard deviation noise_lambda**2.
# NOTE(review): normal() takes the standard deviation as its second argument; a
# Poisson-like approximation would use sqrt(noise_lambda) - confirm the intent.
noise_lambda = 0.5
# Seeded generator so the noisy inputs are identical on every Restart & Run All
noise_rng = np.random.default_rng(0)
noise_train = noise_rng.normal(noise_lambda, noise_lambda**2, tuple(data_train.shape)).astype(np.float32)
noise_test = noise_rng.normal(noise_lambda, noise_lambda**2, tuple(data_test.shape)).astype(np.float32)

In [None]:
# pseudo "poisson" transform to small range (less noise)
#noise_mu = 0.5
#noise_sigma = noise_mu * (scale/2) ** -0.5 
#noise_train = np.random.normal(noise_mu, noise_sigma, data_train.shape).astype(np.float32)
#noise_test = np.random.normal(noise_mu, noise_sigma, data_test.shape).astype(np.float32)

In [None]:
# Network input: pixels divided by the training maximum, plus additive noise.
# The noise arrays are NumPy; convert them explicitly instead of relying on
# implicit Tensor/ndarray mixing inside the arithmetic.
x_train = data_train / scale + torch.as_tensor(noise_train)
# Target: boolean mask of pixels brighter than half the maximum intensity
y_train = data_train > scale / 2
x_test = data_test / scale + torch.as_tensor(noise_test)
y_test = data_test > scale / 2

In [None]:
# Side-by-side view of one raw training image and its noisy network input
ind = 0
fig, axs = plt.subplots(1, 2, figsize=[10, 5])
axs[0].imshow(data_train[ind])
axs[0].set_title('Original')
axs[1].imshow(x_train[ind])
axs[1].set_title('Scaled + noise');

In [None]:
# Pack input and target into one tensor along a new channel axis:
# channel 0 holds the noisy image, channel 1 the mask.
train = torch.cat((x_train[:, None], y_train[:, None]), dim=1)
test = torch.cat((x_test[:, None], y_test[:, None]), dim=1)

In [None]:
# Training batches are reshuffled every epoch
train_loader = DataLoader(train, batch_size=128, shuffle=True, num_workers=0)
# The summed evaluation loss does not depend on batch order, so keep it fixed
test_loader = DataLoader(test, batch_size=128, shuffle=False, num_workers=0)

## Model and Hyperparameters

In [None]:
# Instantiate the network imported from the local model.py with its default constructor arguments
model = UNET()

In [None]:
# Use Apple's Metal backend when it is available, otherwise fall back to the CPU
device_name = 'mps' if torch.backends.mps.is_available() else 'cpu'
device = torch.device(device_name)
# Trailing semicolon hides the module repr that .to() would otherwise display
model.to(device);

In [None]:
# Binary cross-entropy summed over every element of a batch; train_loop/test_loop
# divide the accumulated sum by the dataset size afterwards.
# NOTE(review): BCELoss expects predictions in [0, 1] - presumably UNET ends in a sigmoid; confirm in model.py.
criterion = nn.BCELoss(reduction='sum')

In [None]:
# Adam over all model parameters, using the library's default hyperparameters
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# One [name, element count] row per named parameter tensor of the model
name_params = [
    [param_name, tensor.numel()]
    for param_name, tensor in model.named_parameters()
]
df_np = pd.DataFrame(name_params, columns=['name', 'parameters'])
# Total number of parameter elements (displayed as the cell output)
df_np['parameters'].sum()

## Train

In [None]:
def train_loop():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``train_loader``, ``device``,
    ``criterion`` and ``optimizer``. Each batch carries the network input in
    channel 0 and the target in the remaining channel(s).

    Returns:
        float: the accumulated batch losses divided by the number of samples
        in the training dataset.
    """
    # Set model to training mode
    model.train()
    loss_total = 0.0
    for batch in train_loader:
        # Channel 0 is the input; slicing with 0:1 keeps the channel axis
        data = batch[:, 0:1, :, :].to(device)
        target = batch[:, 1:, :, :].to(device)

        # Drop gradients from the previous step before the new forward pass
        optimizer.zero_grad()
        pred = model(data)
        loss = criterion(pred, target)

        # Backprop and parameter update
        loss.backward()
        optimizer.step()
        loss_total += loss.item()

    # len() works for any sized dataset, not only for a bare tensor
    return loss_total / len(train_loader.dataset)

In [None]:
def test_loop():
    """Evaluate the model on ``test_loader`` without updating it.

    Uses the notebook-level ``model``, ``test_loader``, ``device`` and
    ``criterion``. Each batch carries the network input in channel 0 and the
    target in the remaining channel(s).

    Returns:
        float: the accumulated batch losses divided by the number of samples
        in the test dataset.
    """
    # Set model to evaluation mode
    model.eval()
    loss_total = 0.0
    # Do not calculate gradients
    with torch.no_grad():
        for batch in test_loader:
            # Channel 0 is the input; slicing with 0:1 keeps the channel axis
            data = batch[:, 0:1, :, :].to(device)
            target = batch[:, 1:, :, :].to(device)

            pred = model(data)
            loss_total += criterion(pred, target).item()

    # len() works for any sized dataset, not only for a bare tensor
    return loss_total / len(test_loader.dataset)

In [None]:
# Collect [train, test] loss pairs per epoch, then expose them as an array
# named `loss` (column 0 = train, column 1 = test) for the plotting cells.
loss_history = []
epochs = 10
for epoch in tqdm(range(epochs), total=epochs):
    loss_train = train_loop()
    loss_test = test_loop()
    loss_history.append([loss_train, loss_test])
    # epoch is an integer counter; the old ':.3f' spec rendered it as "0.000"
    print(f'Epoch:{epoch} - Train loss: {loss_train:.3f} - Test loss: {loss_test:.3f}')
loss = np.array(loss_history)

## Evaluate

In [None]:
# Column 0 holds the training loss, column 1 the test loss (one row per epoch)
plt.plot(loss[:, 0], label='train')
plt.plot(loss[:, 1], label='test')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend();

In [None]:
ind = 1000
# Add batch and channel axes so the single image has the layout the model is fed in training
sample = x_test[ind].unsqueeze(0).unsqueeze(1).to(device)
# Make inference independent of whichever cell ran last: eval mode, no autograd
model.eval()
with torch.no_grad():
    pred = model(sample).cpu().numpy()[0, 0]
fig, axs = plt.subplots(1, 3, figsize=[15, 5])
axs[0].imshow(x_test[ind])
axs[0].set_title('Noisy input')
axs[1].imshow(y_test[ind])
axs[1].set_title('Target mask')
axs[2].imshow(pred)
axs[2].set_title('Prediction');

## Conv feature maps

In [None]:
# Pass the selected test sample through model.block1 and then model.mp.
# NOTE(review): presumably block1 is the first encoder stage and mp a max-pool - confirm in model.py.
# This runs outside torch.no_grad(); the plotting cells call .detach() on the result.
out1 = model.mp(model.block1(sample))

In [None]:
# n x n grid of the individual channels of out1 (first batch element);
# panels beyond the available channel count stay empty.
n = 6
fig, axs = plt.subplots(n, n, figsize=[10, 10], sharex=True, sharey=True)
# axs.flat walks the grid row by row; zip stops at the shorter of axes/channels
for ax, channel in zip(axs.flat, out1[0]):
    ax.imshow(channel.cpu().detach())
plt.tight_layout()

In [None]:
# Average of out1 over the batch and channel axes, shown as a single image
plt.imshow(out1.cpu().detach().numpy().mean(axis=(0, 1)))

In [None]:
# Feed out1 through model.block2 and then model.mp.
# NOTE(review): presumably the second encoder stage followed by a max-pool - confirm in model.py.
out2 = model.mp(model.block2(out1))

In [None]:
# n x n grid of the individual channels of out2 (first batch element);
# panels beyond the available channel count stay empty.
n = 8
fig, axs = plt.subplots(n, n, figsize=[10, 10], sharex=True, sharey=True)
# axs.flat walks the grid row by row; zip stops at the shorter of axes/channels
for ax, channel in zip(axs.flat, out2[0]):
    ax.imshow(channel.cpu().detach())
plt.tight_layout()

In [None]:
# Average of out2 over the batch and channel axes, shown as a single image
plt.imshow(out2.cpu().detach().numpy().mean(axis=(0, 1)))

In [None]:
# Feed out2 through model.block3 and then model.mp.
# NOTE(review): presumably the third encoder stage followed by a max-pool - confirm in model.py.
out3 = model.mp(model.block3(out2))

In [None]:
# n x n grid of the individual channels of out3 (first batch element);
# panels beyond the available channel count stay empty.
n = 12
fig, axs = plt.subplots(n, n, figsize=[10, 10], sharex=True, sharey=True)
# axs.flat walks the grid row by row; zip stops at the shorter of axes/channels
for ax, channel in zip(axs.flat, out3[0]):
    ax.imshow(channel.cpu().detach())
plt.tight_layout()

In [None]:
# Average of out3 over the batch and channel axes, shown as a single image
plt.imshow(out3.cpu().detach().numpy().mean(axis=(0, 1)))