In [None]:
from google.colab import drive

# Mount Google Drive so the project directory under "My Drive" is reachable from the kernel.
drive.mount(mountpoint='/content/drive')

In [None]:
import sys

# Presumably where the project's own modules (fcn, fcn8, img_utils, loss_iter, ...) live.
sys.path.extend(['/content/drive/My Drive/Colab Notebooks/CondNet'])

In [None]:
from tqdm.notebook import tqdm

import importlib
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from torch import Tensor, Size
from torch import autograd
from torch.nn import Module
from typing import Optional
from torch.nn import functional as F

import fcn
import fcn8
import img_utils as iu
import loss_iter
import CondNet

In [None]:
# Reproducibility: seed NumPy and torch, and force cuDNN to pick deterministic kernels.
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Pick up any edits made to img_utils, then ask it to generate 10000 images.
iu = importlib.reload(iu)
iu.gen_images(10000)

In [None]:
class ReductionBlock(nn.Module):
    """Summarise the whole input into a 96-channel context map of the input's size.

    Four (4x4 conv -> 2x2 max-pool) stages shrink the input; the reduced map is then
    resized with nearest-neighbour interpolation back to the input's height and width,
    so it can be concatenated with the input along the channel axis. For a 64x64 input
    the reduced map is 1x1, so every pixel of the output holds the same 96-dim vector.

    Inputs must be at least 61x61, otherwise the last max-pool has nothing to pool.

    Args:
        in_channels: number of channels of the (N, C, H, W) input tensor.
    """

    def __init__(self, in_channels):
        super(ReductionBlock, self).__init__()
        self.in_channels = in_channels
        # Layer order/indices are unchanged, so previously saved state_dicts still load.
        # NOTE(review): there is no activation between the convs; only the max-pools
        # are non-linear.
        self.net = nn.Sequential(
            nn.Conv2d(self.in_channels, 32, kernel_size=(4, 4)),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(32, 48, kernel_size=(4, 4)),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(48, 64, kernel_size=(4, 4)),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(64, 96, kernel_size=(4, 4)),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
        )

    def forward(self, x):
        reduced = self.net(x)
        # Resize to the input's spatial size. The former fixed x64 upsample only
        # matched 64x64 inputs; for any other size torch.cat in the caller failed.
        return nn.functional.interpolate(reduced, size=tuple(x.shape[-2:]), mode='nearest')

In [None]:
class MainBlock(nn.Module):
    """One CondNet stage producing 128 channels at the input's spatial size.

    Steps: batch-normalise the input, concatenate the 96-channel map produced by
    ``ReductionBlock`` along the channel axis, apply a 64-filter convolution that keeps
    height and width, then append 64 more channels holding each sample's per-channel
    spatial mean of the conv output (64 + 64 = 128 channels out).

    Args:
        in_channels: number of channels of the (N, C, H, W) input tensor.
        kernel_size: kernel size of the main convolution. Odd sizes use symmetric
            ``kernel_size // 2`` padding; even sizes are zero-padded by
            ``kernel_size - 1`` on the left/top only, so the output keeps H and W.
    """

    def __init__(self, in_channels, kernel_size):
        super(MainBlock, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.reduction_block = ReductionBlock(self.in_channels)
        self.batch_norm = nn.BatchNorm2d(self.in_channels)

        if self.kernel_size % 2 == 0:
            # An even kernel cannot be centred: pad (k - 1) on the left/top so that
            # H + (k - 1) - k + 1 == H. The pad used to be hard-coded to 3 (k == 4 only).
            pad = self.kernel_size - 1
            self.zero_pad = nn.ZeroPad2d(padding=(pad, 0, pad, 0))
            self.conv_layer = nn.Conv2d(self.in_channels + 96, 64, kernel_size=self.kernel_size, padding=0)
        else:
            self.conv_layer = nn.Conv2d(self.in_channels + 96, 64, kernel_size=self.kernel_size, padding=self.kernel_size // 2)

    def forward(self, x):
        x = self.batch_norm(x)
        reduced_x = self.reduction_block(x)
        x = torch.cat((x, reduced_x), dim=1)
        if self.kernel_size % 2 == 0:
            x = self.zero_pad(x)
        x = self.conv_layer(x)
        # Global context per sample and per channel. The previous x.mean() mixed the
        # whole batch, so a sample's output depended on its batch-mates, and its
        # .cuda() ones tensor broke CPU execution. expand_as keeps x's device/dtype.
        x_mean = x.mean(dim=(2, 3), keepdim=True).expand_as(x)
        return torch.cat((x, x_mean), dim=1)

In [None]:
class CondNet(nn.Module):
    """Three ``MainBlock`` stages followed by three per-pixel MLP heads.

    NOTE: this class shadows the ``CondNet`` module imported at the top of the notebook.

    Args:
        in_channels: number of channels of the (N, C, H, W) input image.
        n_class: number of segmentation classes produced by the ``seg`` head.

    Returns (from ``forward``):
        beta:  (N, H', W') raw output of the single-unit ``cond_linear`` head.
        clust: (N, H', W', 2) output of the ``cluster_linear`` head.
        seg:   (N, n_class, H', W') output of the ``seg_linear`` head.
    """

    def __init__(self, in_channels, n_class=3):
        super(CondNet, self).__init__()
        self.in_channels = in_channels
        self.n_class = n_class
        # Each MainBlock emits 128 channels, hence the 128 inputs of later stages/heads.
        self.net = nn.Sequential(
            MainBlock(self.in_channels, 3),
            MainBlock(128, 4),
            MainBlock(128, 5)
        )

        # Heads are built in the same order and with the same layer indices as before,
        # so initialisation and saved state_dicts are unaffected.
        self.cond_linear = self._mlp_head(1)
        self.cluster_linear = self._mlp_head(2)
        # Previously hard-coded to 3 outputs, silently ignoring n_class.
        self.seg_linear = self._mlp_head(self.n_class)

    @staticmethod
    def _mlp_head(out_features):
        """Build a 128 -> 64 -> 64 -> out_features MLP with ELU activations."""
        return nn.Sequential(
            nn.Linear(128, 64),
            nn.ELU(),
            nn.Linear(64, 64),
            nn.ELU(),
            nn.Linear(64, out_features),
        )

    def forward(self, x):
        features = self.net(x)
        # Channels-last so the nn.Linear heads act on the 128 features of each pixel.
        features = features.permute(0, 2, 3, 1)
        beta = self.cond_linear(features).squeeze(3)
        clust = self.cluster_linear(features)
        seg = self.seg_linear(features).permute(0, 3, 1, 2)
        return beta, clust, seg

In [None]:
# Re-seed immediately before building the model so its weight init is reproducible on its own.
torch.manual_seed(0)
np.random.seed(0)

test_net = CondNet(3)
test_net.cuda()

# Adam with learning rate 5e-4 over all network parameters.
opt1 = optim.Adam(test_net.parameters(), lr=5e-4)

In [None]:
importlib.reload(loss_iter)
# reduction='none' keeps one loss value per element (the deprecated `reduce=False` meant
# the same); targets equal to -1 are ignored.
loss_func = loss_iter.CondLoss(
    loss_function=nn.CrossEntropyLoss(reduction='none', ignore_index=-1),
    q_min=0.4,
    supression=10,  # sic: keyword spelled this way by loss_iter.CondLoss
    cond_weight=0.4,
    cuda=True,
)

In [None]:
# Dataset over the generated images; the empty string is forwarded unchanged to
# img_utils.ImageDataset (presumably a directory/path prefix — TODO confirm).
ds = iu.ImageDataset("")

In [None]:
# Shuffle every epoch so training does not see the same fixed batches each time;
# the seeded generator keeps the batch order reproducible across runs.
dataloader = torch.utils.data.DataLoader(
    ds,
    batch_size=100,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
from matplotlib.lines import Line2D   

def plot_grad_flow(named_parameters):
    '''Plot per-layer gradient magnitudes to spot vanishing / exploding gradients.

    Call after ``loss.backward()`` and before the next ``zero_grad()``, e.g.
    ``plot_grad_flow(model.named_parameters())``. Bias parameters, parameters with
    ``requires_grad=False`` and parameters whose ``.grad`` is ``None`` (not reached by
    the last backward pass, or cleared by ``zero_grad``) are skipped.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, as returned by
            ``nn.Module.named_parameters()``.
    '''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and "bias" not in n and p.grad is not None:
            layers.append(n)
            grad = p.grad.detach().abs()
            # .item() copies to host: matplotlib cannot convert CUDA tensors to arrays.
            ave_grads.append(grad.mean().item())
            max_grads.append(grad.max().item())
    positions = np.arange(len(max_grads))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(positions, layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])

In [None]:


N_DEBUG_BATCHES = 9  # batches whose intermediate loss terms are printed
N_DEBUG_STEPS = 8    # of those, how many also run backward() and an optimiser step

test_net.cuda()
test_net.train()
for i, (data, mask, matrix) in enumerate(dataloader):
    if i >= N_DEBUG_BATCHES:
        # Stop instead of loading (and ignoring) every remaining batch.
        break
    print(f"i is {i}")
    data = data.cuda().float()
    mask = mask.cuda().squeeze(1)
    matrix = matrix.cuda().float()
    with autograd.detect_anomaly():
        beta, test_x, segmentation = test_net(data)
        batch_size = data.shape[0]
        test_matrix = matrix.reshape(matrix.shape[0], matrix.shape[1], -1)
        test_out = nn.Sigmoid()(beta)
        test_beta = test_out.reshape(test_out.shape[0], -1)
        # Count values on the boundary of atanh's domain (presumably loss_func.atanh is
        # the inverse tanh): sigmoid can saturate to exactly 1.0 in float32, where atanh
        # diverges. The previous strict `< -1` / `> 1` test could never fire.
        print(torch.sum(test_beta <= -1), torch.sum(test_beta >= 1))
        test_q = loss_func.atanh(test_beta) ** 2 + loss_func.q_min
        print(f"q NaN count: {torch.sum(torch.isnan(test_q))}")
        print(f"q max: {torch.max(test_q - loss_func.q_min, dim=1)}")
        test_noise = (torch.sum(test_matrix, dim=1) < 1).float()
        print(f"Noise vertices: {torch.sum(test_noise, dim=1)}")
        print(f"test_noise : {torch.sum((1 - test_noise) * (test_q - loss_func.q_min), dim=1)}")
        print(mask.shape, segmentation.shape)
        temp_general_loss = loss_func.general_loss(test_noise, test_q, segmentation, mask)
        temp_background_loss = loss_func.background_loss(test_beta, test_matrix, test_noise, test_matrix.shape[1])
        # Flatten cluster coordinates per sample. The batch dim used to be hard-coded
        # to 10, which mixed samples together for any other batch size.
        # NOTE(review): 64, 64 are presumably the image height/width — confirm in loss_iter.
        temp_potential_loss = loss_func.potential_loss(test_x.reshape(batch_size, -1, 2), test_q, test_matrix, 64, 64, test_matrix.shape[1])
        print(f"General loss: {temp_general_loss}")
        print(f"Background loss: {temp_background_loss}")
        print(f"Potential loss: {temp_potential_loss}")
        print(f"Overall loss: {(temp_general_loss + loss_func.cond_weight * (temp_background_loss + temp_potential_loss)).mean()}")
        if i < N_DEBUG_STEPS:
            # Zero right before backward, so inspection-only iterations leave the last
            # step's gradients on the parameters (e.g. for plot_grad_flow afterwards).
            opt1.zero_grad()
            curr_loss = loss_func(test_x, test_out, matrix, segmentation, mask)
            print(torch.sum(torch.isnan(curr_loss)))
            print(curr_loss.item())
            curr_loss.backward()
            opt1.step()

In [None]:
# Visualise the gradients currently stored on test_net's parameters.
plot_grad_flow(named_parameters=test_net.named_parameters())

In [None]:
n_epochs = 20
# Anomaly detection slows every backward pass considerably; keep it on only while
# hunting NaN/inf gradients.
DETECT_ANOMALY = True
loss_vs_epoch = []  # mean training loss per epoch, weighted by batch size
test_net.float()
test_net.train()
sigmoid = nn.Sigmoid()
for epoch in tqdm(range(n_epochs)):
    temp_loss = 0
    for i, (data, mask, matrix) in enumerate(dataloader):
        opt1.zero_grad()
        data = data.cuda().float()
        mask = mask.cuda().squeeze(1)
        matrix = matrix.cuda().float()
        # Reset per batch so the error report below never shows a stale loss left over
        # from an earlier batch or cell (or hits a NameError if none exists yet).
        curr_loss = None
        try:
            with autograd.set_detect_anomaly(DETECT_ANOMALY):
                output = test_net(data)
                curr_loss = loss_func(output[1], sigmoid(output[0]), matrix, output[2], mask)
                curr_loss.backward()
                opt1.step()
        except Exception as e:
            print(i, curr_loss)
            print(e)
            raise
        temp_loss += curr_loss.item() * data.size(0)
    loss_vs_epoch.append(temp_loss / len(dataloader.sampler))
    print(f"Epoch: {epoch}, Loss: {loss_vs_epoch[-1]}")
    # Per-epoch checkpoint in the current working directory.
    torch.save(test_net.state_dict(), 'net_my_fcn.pt')

torch.save(test_net.state_dict(), '/content/drive/My Drive/Colab Notebooks/CondNet/net_my_fcn.pt')

In [None]:
# Epoch 0 is left out (as before); the x values are the epoch indices printed during
# training, so the skipped epoch no longer shifts the axis by one.
plt.plot(range(1, len(loss_vs_epoch)), loss_vs_epoch[1:], marker="o")
plt.xlabel("epoch")
plt.ylabel("mean training loss")
plt.title("Training loss (epoch 0 omitted)")
plt.show()

In [None]:
item = 133  # index of the dataset sample to visualise

data, mask, _ = ds[item]

data = data.unsqueeze(0).cuda().float()
test_net.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    test_out = test_net(data)

beta_out, clust_out, seg_out = test_out

fig, ax = plt.subplots(3, 3, figsize=(9, 9))

# Row 0: beta (sigmoid applied, as in the training loss) and both cluster coordinates.
# Previously all three panels showed the same raw beta map.
ax[0, 0].imshow(torch.sigmoid(beta_out[0]).cpu().numpy())
ax[0, 0].set_title("sigmoid(beta)")
for k in range(2):
    ax[0, k + 1].imshow(clust_out[0, :, :, k].cpu().numpy())
    ax[0, k + 1].set_title(f"cluster coord {k}")

# Row 1: raw per-class segmentation outputs.
for c in range(3):
    ax[1, c].imshow(seg_out[0, c, :, :].cpu().numpy())
    ax[1, c].set_title(f"seg output, class {c}")

# Row 2: ground truth next to the predicted class (previously the mask three times).
ax[2, 0].imshow(mask[0, :, :].cpu().numpy())
ax[2, 0].set_title("ground-truth mask")
ax[2, 1].imshow(seg_out[0].argmax(dim=0).cpu().numpy())
ax[2, 1].set_title("predicted class (argmax)")
ax[2, 2].axis("off")

plt.tight_layout()
plt.show()