### Tuning a Full-Scale Connected Skip Connections U-Net with DARTs
This notebook focuses on tuning the skip connections (in particular) between the encoding and decoding branches. These skip connections have a large impact on the final preformance of the model. Evaluated on the HAM10000 dataset.

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
from collections import namedtuple
from skimage.color import rgb2gray
import matplotlib.pyplot as plt
from google.colab import files
from datetime import datetime
from scipy.stats import norm
from skimage import io
import seaborn as sns
import pandas as pd
import numpy as np
import logging
import random
import string
import heapq
import glob
import time
import math
import cv2
import sys
import ast
import os
import re

In [2]:
from torch.utils.data import Dataset, DataLoader
from skimage import data, transform, util
import torchvision.transforms as transforms
from torchsummary import torchsummary
from skimage.transform import resize
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import torch

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


# 1. Data

In [5]:
class HAM10000_Dataset(Dataset):

  def __init__(self, csv_file, num_entries = 10000):
    self.num_entries = num_entries
    self.data = pd.read_csv(csv_file, usecols = [ 1, 2])
    assert self.num_entries <= len(self.data)

  def __len__(self):
    return self.num_entries

  def __getitem__(self, index):
    ind = index if index < (self.num_entries // 2) else index + (len(self.data) - self.num_entries) // 2

    downscale = lambda x: F.interpolate(x, size = (128, 128), mode = 'bilinear', align_corners = True)

    img = io.imread(self.data.iloc[ind, 0])
    img_tensor = downscale(transforms.ToTensor()(img).unsqueeze(0)).squeeze(0)

    pmask = (downscale(transforms.ToTensor()(rgb2gray(io.imread(self.data.iloc[ind, 1]))).unsqueeze(0))).bool().squeeze(0)
    mask_tensor = torch.cat((pmask.float(), (~pmask).float()), dim = 0)

    return (img_tensor.float(), mask_tensor)

# 2. Operations

In [6]:
class DownSampl(nn.Module):
    '''
    Bilinear downscale for images
    '''
    def __init__(self, d_img_in, d_img_out):
        super(DownSampl, self).__init__()
        self.op = lambda x: F.interpolate(
            x,
            size = (d_img_out, d_img_out),
            mode = 'bilinear',
            align_corners = True
            )

    def forward(self, x):
        return self.op(x)

In [7]:
class UpSampl(nn.Module):

    def __init__(self, d_img_in, d_img_out):
        '''
        Bilinear upscale for images
        '''
        super(UpSampl, self).__init__()
        scale = d_img_out // d_img_in
        self.op = torch.nn.Upsample(scale_factor = scale, mode = 'bilinear')

    def forward(self, x):
        return self.op(x)

In [8]:
class Zero(nn.Module):

  def __init__(self, d_out, stride):
    super(Zero, self).__init__()
    self.stride = stride
    self.d_out = d_out

  def forward(self, x):
    if self.stride == 1:
      return torch.cat([x.mul(0.)[:, :1, :, :] for _ in range(self.d_out)], dim = 1)
    return torch.cat([x[:,:,::self.stride,::self.stride].mul(0.)[:, :1, :, :] for _ in range(self.d_out)], dim = 1)

In [9]:
class SkipConnection(nn.Module):

  def __init__(self, d_in, d_out):
    super(SkipConnection, self).__init__()
    self.d_in = d_in
    self.d_out = d_out

  def forward(self, x):
    if self.d_in == self.d_out:
        return x

    elif self.d_in < self.d_out:
        seq = np.array([[i for _ in range(self.d_out // self.d_in)] for i in range(self.d_in)]).flatten()
        return torch.cat(
            [x[:, i:i+1, :, :] for i in seq],
            dim = 1
            )

    else:
        glen = self.d_in // self.d_out
        groups = torch.split(x, glen, dim=1)
        op1 = nn.AvgPool3d(kernel_size=(glen, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0)).cuda()
        out = []
        for g in groups:
            out.append(op1(g))
        return torch.cat(out, dim=1)

In [10]:
class MaxPoolBN(nn.Module):

  def __init__(self, d_in, d_out, stride):
    super(MaxPoolBN, self).__init__()
    self.op = nn.Sequential(
      nn.MaxPool2d(3, stride = stride, padding = 1),
      nn.BatchNorm2d(d_in, affine = False)
    )

    self.d_out = d_out
    self.d_in = d_in

  def forward(self, x):
    '''
    MaxPoolBN cannot adjust the thickenss of the layers because it is a filter
    slid over each layer, we need to account for d_in != d_out. What makes this
    easier is they they are all factors of 2 larger or smaller than eachother.

    If d_out > d_in, we duplicate in the input image's channels until we reach
    the dimensionality of d_out. If d-in is larger than d_out, we apply a
    3D Max Pool groupwise to the input image to reduce dimensionality.
    '''
    out = self.op(x)

    if self.d_out > self.d_in:
      seq = np.array([[i for _ in range(self.d_out // self.d_in)] for i in range(self.d_in)]).flatten()
      return torch.cat(
          [out[:, i:i+1, :, :] for i in seq],
          dim = 1
          )

    elif self.d_out < self.d_in:
        glen = self.d_in // self.d_out
        groups = torch.split(x, glen, dim=1)
        op1 = nn.MaxPool3d(kernel_size=(glen, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)).cuda()
        out = []
        for g in groups:
            out.append(op1(g))
        out = torch.cat(out, dim=1)
        bn = nn.BatchNorm2d(self.d_out, affine = False).cuda()
        return bn(out)

    return out


In [11]:
class AvgPoolBN(nn.Module):

  def __init__(self, d_in, d_out, stride):
    super(AvgPoolBN, self).__init__()
    self.op = nn.Sequential(
      nn.AvgPool2d(3, stride = stride, padding = 1),
      nn.BatchNorm2d(d_in, affine = False)
    )

    self.d_out = d_out
    self.d_in = d_in

  def forward(self, x):
    out = self.op(x)

    if self.d_out > self.d_in:
      seq = np.array([[i for _ in range(self.d_out // self.d_in)] for i in range(self.d_in)]).flatten()
      return torch.cat(
          [out[:, i:i+1, :, :] for i in seq],
          dim = 1
          )

    elif self.d_out < self.d_in:
        glen = self.d_in // self.d_out
        groups = torch.split(x, glen, dim=1)
        op1 = nn.AvgPool3d(kernel_size=(glen, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)).cuda()
        out = []
        for g in groups:
            out.append(op1(g))
        out = torch.cat(out, dim=1)
        bn = nn.BatchNorm2d(self.d_out, affine = False).cuda()
        return bn(out)

    return out

In [12]:
class ReLUConvBN(nn.Module):

  def __init__(self, d_in, d_out, kernel_size, stride, padding, affine=True):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace = False),
      nn.Conv2d(d_in, d_out, kernel_size, stride = stride, padding = padding, bias = False),
      nn.BatchNorm2d(d_out, affine = affine)
    )

  def forward(self, x):
    return self.op(x)

In [13]:
class DilConv(nn.Module):

  def __init__(self, d_in, d_out, kernel_size, stride, padding, dilation, affine = True):
    super(DilConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(d_in, d_out, kernel_size = kernel_size, stride = stride, padding = padding, dilation = dilation, bias = False),
      nn.BatchNorm2d(d_out, affine = affine),
      )

  def forward(self, x):
    return self.op(x)

In [14]:
class DepthSepConv(nn.Module):

  def __init__(self, d_in, d_out, kernel_size, stride, padding, affine = True):
    super(DepthSepConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace = False),
      nn.Conv2d(d_in, d_in, kernel_size = kernel_size, stride = stride, padding = padding, groups = d_in, bias = False),
      nn.Conv2d(d_in, d_out, kernel_size = 1, padding = 0, bias = False),
      nn.BatchNorm2d(d_out, affine = affine)
    )

  def forward(self, x):
    return self.op(x)


In [15]:
class SpatialSepConv(nn.Module):

  def __init__(self, d_in, d_out, kernel_size, stride, padding, affine = True):
    super(SpatialSepConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(d_in, d_in, (1, kernel_size), stride = (1, stride), padding = (0, padding), bias = False),
      nn.Conv2d(d_in, d_out, (kernel_size, 1), stride = (stride, 1), padding = (padding, 0), bias = False),
      nn.BatchNorm2d(d_out, affine = affine)
      )

  def forward(self, x):
    return self.op(x)

In [16]:
OPS = {
  'none' : lambda d_in, d_out, stride, affine: Zero(d_out, stride),
  'skip_connect' : lambda d_in, d_out, stride, affine: SkipConnection(d_in, d_out),
  'max_pool_3x3' : lambda d_in, d_out, stride, affine: MaxPoolBN(d_in, d_out, stride = stride),
  'avg_pool_3x3' : lambda d_in, d_out, stride, affine: AvgPoolBN(d_in, d_out, stride = stride),
  'dil_conv_3x3' : lambda d_in, d_out, stride, affine: DilConv(d_in, d_out, 3, stride, 2, 2, affine = affine),
  'dil_conv_5x5' : lambda d_in, d_out, stride, affine: DilConv(d_in, d_out, 5, stride, 4, 2, affine = affine),
  'depth_sep_conv_3x3' : lambda d_in, d_out, stride, affine: DepthSepConv(d_in, d_out, 3, stride, 1, affine = affine),
  'depth_sep_conv_5x5' : lambda d_in, d_out, stride, affine: DepthSepConv(d_in, d_out, 5, stride, 2, affine = affine),
  'spatial_sep_conv_5x5' : lambda d_in, d_out, stride, affine: SpatialSepConv(d_in, d_out, 5, stride, 2, affine = affine),
  'spatial_sep_conv_7x7' : lambda d_in, d_out, stride, affine: SpatialSepConv(d_in, d_out, 7, stride, 3, affine = affine)
}

PRIMITIVES = [
    'none',
    'skip_connect',
    'max_pool_3x3',
    'avg_pool_3x3',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'depth_sep_conv_3x3',
    'depth_sep_conv_5x5',
    'spatial_sep_conv_5x5',
    'spatial_sep_conv_7x7'
]

# 3. Mixed Operation

In [17]:
class MixedOp(nn.Module):

  def __init__(self, d_in, d_out, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](d_in, d_out, stride, False)
      self._ops.append(op)

  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))

# 4. Cell

In [18]:
class Cell(nn.Module):

  def __init__(self, d_in, d_out, d_img_in, d_img_out):
    super(Cell, self).__init__()

    self.num_ops = len(PRIMITIVES)

    # Dimensions of the input feature map and output feature map
    # Used in initial operations s0 to increase storage information as we compress
    self.d_in = d_in
    self.d_out = d_out

    # Dimensions of the input image size and output image size
    # Used for down/up sampling in initial operations (out of s0)
    self.d_img_in = d_img_in
    self.d_img_out = d_img_out

    # Downsample
    if d_img_in > d_img_out:
      self.c0 = DownSampl(d_img_in, d_img_out)

    # Upsample
    elif self.d_img_out > self.d_img_in:
      self.c0 = UpSampl(d_img_in, d_img_out)

    else:
      self.c0 = lambda x: x

    self._ops = nn.ModuleList()

    for i in range(3):
      for j in range(i + 1):
        if j == 0:
          op = MixedOp(self.d_in, self.d_out // 2, 1)
        else:
          op = MixedOp(self.d_out // 2, self.d_out // 2, 1)
        self._ops.append(op)

  def forward(self, s0, weights):
    # nops = num_ops = 10
    # Node 0 (s0) -> Node 1: weights (0 -> nops)
    # Node 0 (s0) -> Node 2: weights (nops -> 2 * nops)
    # Node 1 -> Node 2: weights (2 * nops -> 3 * nops)
    # Node 0 (s0) -> Node 3: weights (3 * nops -> 4 * nops)
    # Node 1 -> Node 3: weights (4 * nops -> 5 * nops)
    # Node 2 -> Node 3: weights (5 * nops -> 6 * nops)
    # Node 0, Node 1, Node 2, Node 3 -> SSB: weights (6 * nops -> 6 * nops + 4)
    # Node 0, Node 1, Node 2, Node 3 -> CSB: weights (6 * nops + 5 -> 6 * nops + 8)
    # Total Alphas per Cell: 68

    s0 = self.c0(s0.cuda())
    concat_states = [s0]
    offset = 0

    ssb_sum_weights = F.softmax(weights[-8:-4], dim = -1).cuda()
    csb_concat_weights = F.sigmoid(weights[-4:]).cuda()

    for i in range(3):

      s = sum(torch.stack([self._ops[offset + j](h, weights[(offset + j) * self.num_ops: (offset + j + 1) * self.num_ops]) for j, h in enumerate(concat_states)]))
      offset += len(concat_states)
      concat_states.append(s)

    # Adding or removing feature maps from s0 to they can be combined w/ the output
    if (self.d_out // 2) > self.d_in:
      seq = np.array([[i for _ in range((self.d_out // 2) // self.d_in)] for i in range(self.d_in)]).flatten()
      s0 = torch.cat([s0[:, i:i+1, :, :] for i in seq], dim = 1)

    elif (self.d_out // 2) < self.d_in:
      op0 = nn.Conv2d(self.d_in, self.d_out // 2, 1, groups = self.d_out // 2).cuda()
      s0 = op0(s0)

    # Update s0 after they have been used as input to other branches to be able to be combined with output SSB and CSB
    concat_states[0] = s0
    shape = concat_states[0].shape

    sums = sum([
        (torch.ones(shape).cuda() * ssb_sum_weights[i]) * concat_states[i] for i in range(len(concat_states))
        ])

    csb_pre_out = torch.cat([
        (torch.ones(concat_states[i].shape).cuda() * csb_concat_weights[i]) * concat_states[i] for i in range(len(concat_states)
        )], dim = 1)

    csb_out = nn.Conv2d(in_channels = (self.d_out // 2) * len(concat_states), out_channels = self.d_out // 2, kernel_size = 1, stride = 1, padding = 0, bias = False).cuda()(csb_pre_out)

    return torch.cat([sums, csb_out], dim = 1)

# 5. Encoder/Decoder Backbone Cell

In [19]:
class BackBoneCell(nn.Module):

  def __init__(self, d_in, d_out, d_img_in, d_img_out):
    '''
    Only the skip connections are tuned, the encoder and
    decoder cells remain the same. Implementation is the same
    as the original U-Net paper.
    '''
    super(BackBoneCell, self).__init__()

    # Downsample
    if d_img_in > d_img_out:
        self.scale = DownSampl(d_img_in, d_img_out)

    # Upsample
    elif d_img_out > d_img_in:
        self.scale = UpSampl(d_img_in, d_img_out)

    else:
        self.scale = lambda x: x

    self.op = nn.Sequential(
        nn.Conv2d(d_in, d_out, 3, padding = 1, stride = 1),
        nn.BatchNorm2d(d_out),
        nn.ReLU(inplace = False),
        nn.Conv2d(d_out, d_out, 3, padding = 1, stride = 1),
        nn.BatchNorm2d(d_out),
        nn.ReLU(inplace = False)
    )

  def forward(self, x):
    out = self.scale(x)
    return self.op(out)

# 6. Network

In [20]:
class Network(nn.Module):

  def __init__(self):
    super(Network, self).__init__()
    self.num_ops = len(PRIMITIVES)

    self.cells = nn.ModuleList()

    # U-Net will be composed of a 5-step encoder, and 4-step decoder
    # Encoders gradually compress image but grow in channels
    # Decoder layers are all 80 feature maps thick, and are composed of a concatenation of 16 feature maps from all encoder nodes
    # We output two feature maps that are compared to the foreground and background grouth truth labels; compressed via 1x1 conv block
    # OPTION: If needed, can replace upsample with upconv instead of bilinear upscale

    # Encoder 1: IMG 128x128 -> 128x128, FEAT 3 -> 8
    self.cells.append(BackBoneCell(3, 8, 128, 128))

    # Encoder 2: IMG 128x128 -> 64x64, FEAT 8 -> 16
    self.cells.append(BackBoneCell(8, 16, 128, 64))

    # Encoder 3: IMG 64x64 -> 32x32, FEAT 16 -> 32
    self.cells.append(BackBoneCell(16, 32, 64, 32))

    # Encoder 4: IMG 32x32 -> 16x16, FEAT 32 -> 64
    self.cells.append(BackBoneCell(32, 64, 32, 16))

    # Encoder 5: IMG 16x16 -> 8x8, FEAT 64 -> 128
    self.cells.append(BackBoneCell(64, 128, 16, 8))

    # ---

    # Skip_1_4: IMG 128x128 -> 16x16, FEAT 8 -> 32 (16 + 16)
    self.cells.append(Cell(8, 32, 128, 16))

    # Skip_2_4: IMG 64x64 -> 16x16, FEAT 16 -> 32 (16 + 16)
    self.cells.append(Cell(16, 32, 64, 16))

    # Skip_3_4: IMG 32x32 -> 16x16, FEAT 32 -> 32 (16 + 16)
    self.cells.append(Cell(32, 32, 32, 16))

    # Skip_4_4: IMG 16x16 -> 16x16, FEAT 64 -> 32 (16 + 16)
    self.cells.append(Cell(64, 32, 16, 16))

    # Enc_Dec_5_4: IMG 8x8 -> 16x16, FEAT 128 -> 16
    self.cells.append(BackBoneCell(128, 16, 8, 16))

    # Decoder 4: IMG 16x16, FEAT 144 -> 16
    self.cells.append(BackBoneCell(144, 16, 16, 16))

    # ---

    # Skip_1_3: IMG 128x128 -> 32x32, FEAT 8 -> 32 (16 + 16)
    self.cells.append(Cell(8, 32, 128, 32))

    # Skip_2_3: IMG 64x64 -> 32x32, FEAT 16 -> 32 (16 + 16)
    self.cells.append(Cell(16, 32, 64, 32))

    # Skip_3_3: IMG 32x32 -> 32x32, FEAT 32 -> 32 (16 + 16)
    self.cells.append(Cell(32, 32, 32, 32))

    # Skip_4_3: IMG 16x16 -> 32x32, FEAT 64 -> 32 (16 + 16)
    self.cells.append(Cell(64, 32, 16, 32))

    # Dec_4_3: IMG 16x16 -> 32x32, FEAT 16 -> 16
    self.cells.append(BackBoneCell(16, 16, 16, 32))

    # Decoder 3: IMG 32x32, FEAT 144 -> 16
    self.cells.append(BackBoneCell(144, 16, 32, 32))

    # ---

    # Skip_1_2: IMG 128x128 -> 64x64, FEAT 8 -> 32 (16 + 16)
    self.cells.append(Cell(8, 32, 128, 64))

    # Skip_2_2: IMG 64x64 -> 64x64, FEAT 16 -> 32 (16 + 16)
    self.cells.append(Cell(16, 32, 64, 64))

    # Skip_3_2: IMG 32x32 -> 64x64, FEAT 32 -> 32 (16 + 16)
    self.cells.append(Cell(32, 32, 32, 64))

    # Skip_4_2: IMG 16x16 -> 64x64, FEAT 64 -> 32 (16 + 16)
    self.cells.append(Cell(64, 32, 16, 64))

    # Dec_3_2: IMG 32x32 -> 64x64, FEAT 16 -> 16
    self.cells.append(BackBoneCell(16, 16, 32, 64))

    # Decoder 2: IMG 64x64, FEAT 144 -> 16
    self.cells.append(BackBoneCell(144, 16, 64, 64))

    # ---

    # Skip_1_1: IMG 128x128 -> 128x128, FEAT 8 -> 32 (16 + 16)
    self.cells.append(Cell(8, 32, 128, 128))

    # Skip_2_1: IMG 64x64 -> 128x128, FEAT 16 -> 32 (16 + 16)
    self.cells.append(Cell(16, 32, 64, 128))

    # Skip_3_1: IMG 32x32 -> 128x128, FEAT 32 -> 32 (16 + 16)
    self.cells.append(Cell(32, 32, 32, 128))

    # Skip_4_1: IMG 16x16 -> 128x128, FEAT 64 -> 32 (16 + 16)
    self.cells.append(Cell(64, 32, 16, 128))

    # Enc_Dec_2_1: IMG 64x64 -> 128x128, FEAT 16 -> 16
    self.cells.append(BackBoneCell(16, 16, 64, 128))

    # Dec_Out_1 1: IMG 128x128, FEAT 144 -> 2
    self.cells.append(BackBoneCell(144, 2, 128, 128))

    self.initialize_alphas()

  def initialize_alphas(self):
    self._arch_parameters = Variable(1e-3 * torch.randn(16, 6 * self.num_ops + 8).cuda(), requires_grad = True)

  def logmeanexp(self, x):
    x_max, _ = torch.max(x, dim = -1, keepdim=True)
    return x_max + torch.log(torch.mean(torch.exp(x - x_max), dim = -1, keepdim = True))

  def mlc_loss(self):
    n = self.num_ops
    return torch.cat((
      self.logmeanexp(self._arch_parameters[:, : n]).flatten(), # Node 1
      self.logmeanexp(self._arch_parameters[:, n : 3 * n]).flatten(), # Node 2
      self.logmeanexp(self._arch_parameters[:, 3 * n : 6 * n]).flatten() # Node 3
    ), dim = 0)

  def arch_parameters(self):
    return self._arch_parameters

  def forward(self, inp):

    # Encoder Branch
    enc1 = self.cells[0](inp)
    enc2 = self.cells[1](enc1)
    enc3 = self.cells[2](enc2)
    enc4 = self.cells[3](enc3)
    enc5 = self.cells[4](enc4)

    # Decoder 4
    skip_1_4 = self.cells[5](enc1, self._arch_parameters[0, :])
    skip_2_4 = self.cells[6](enc2, self._arch_parameters[1, :])
    skip_3_4 = self.cells[7](enc3, self._arch_parameters[2, :])
    skip_4_4 = self.cells[8](enc4, self._arch_parameters[3, :])

    enc_dec_5_4 = self.cells[9](enc5)

    dec_4 = self.cells[10](
        torch.cat((
            skip_1_4,
            skip_2_4,
            skip_3_4,
            skip_4_4,
            enc_dec_5_4
            ), dim = 1)
        )

    # Decoder 3
    skip_1_3 = self.cells[11](enc1, self._arch_parameters[4, :])
    skip_2_3 = self.cells[12](enc2, self._arch_parameters[5, :])
    skip_3_3 = self.cells[13](enc3, self._arch_parameters[6, :])
    skip_4_3 = self.cells[14](enc4, self._arch_parameters[7, :])

    dec_4_3 = self.cells[15](dec_4)

    dec_3 = self.cells[16](
        torch.cat((
            skip_1_3,
            skip_2_3,
            skip_3_3,
            skip_4_3,
            dec_4_3
            ), dim = 1)
        )

    # Decoder 2
    skip_1_2 = self.cells[17](enc1, self._arch_parameters[8, :])
    skip_2_2 = self.cells[18](enc2, self._arch_parameters[9, :])
    skip_3_2 = self.cells[19](enc3, self._arch_parameters[10, :])
    skip_4_2 = self.cells[20](enc4, self._arch_parameters[11, :])

    dec_3_2 = self.cells[21](dec_3)

    dec_2 = self.cells[22](
        torch.cat((
            skip_1_2,
            skip_2_2,
            skip_3_2,
            skip_4_2,
            dec_3_2
            ), dim = 1)
        )

    # Decoder 1
    skip_1_1 = self.cells[23](enc1, self._arch_parameters[12, :])
    skip_2_1= self.cells[24](enc2, self._arch_parameters[13, :])
    skip_3_1 = self.cells[25](enc3, self._arch_parameters[14, :])
    skip_4_1 = self.cells[26](enc4, self._arch_parameters[15, :])

    dec_2_1 = self.cells[27](dec_2)

    dec_1 = self.cells[28](
        torch.cat((
            skip_1_1,
            skip_2_1,
            skip_3_1,
            skip_4_1,
            dec_2_1
        ), dim = 1))

    return dec_1

# 7. Training Utils

In [21]:
# Stores Accuracy, IoU, and Dice data
class AvgrageMeter(object):

    def __init__(self):
        super().__init__()
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.history = np.array([])

    def update(self, val, n = 1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
        self.history = np.append(self.history, self.avg)

    def getAverage(self):
        return self.avg

    def getHistory(self):
        return self.history

In [22]:
# Accuracy calculation
def accuracy(inp, tar):
    acc, inp, tar = [], inp.int(), tar.int()
    for i in range(batch_size):
        for j in range(2):
            cinp, ctar = inp[i, j, :, :], tar[i, j, :, :]
            acc.append(torch.sum(cinp == ctar).item() / ctar.numel())
    return np.mean(acc)

# IoU Calculation
def iou(inp, tar):
    ciou, inp, tar = [], inp.int(), tar.int()
    for i in range(batch_size):
        for j in range(2):
            cinp, ctar = inp[i, j, :, :], tar[i, j, :, :]
            intersection = torch.sum(torch.logical_and(cinp, ctar))
            union = torch.sum(torch.logical_or(cinp, ctar))
            ciou.append(intersection.item() / union.item())
    return np.mean(ciou)


# Dice Calculation
def dice(inp, tar):
    cdice, inp, tar = [], inp.int(), tar.int()
    for i in range(batch_size):
        for j in range(2):
            cinp, ctar = inp[i, j, :, :], tar[i, j, :, :]
            intersection = torch.sum(torch.logical_and(cinp, ctar))
            total_area = torch.sum(cinp) + torch.sum(ctar)
            cdice.append(2 * intersection.item() / total_area.item())
    return np.mean(cdice)


In [23]:
# Infer - Assessing Network Preformance
def infer(valid_queue, model, criterion, discrete = False):
    model.eval()

    test_acc = AvgrageMeter()
    test_iou = AvgrageMeter()
    test_dice = AvgrageMeter()

    for step, (inputd, target) in enumerate(valid_queue):

        inputd = inputd.cuda()
        target = target.cuda(non_blocking=True)

        logits = model(inputd)
        loss = criterion(logits, target)
        logits_binary = (logits >= 0.5).int()

        if discrete:
            test_acc.update(accuracy(logits_binary[:, :1, :, :], target[:, :1, :, :]), n = batch_size)
            test_iou.update(iou(logits_binary[:, :1, :, :], target[:, :1, :, :]), n = batch_size)
            test_dice.update(dice(logits_binary[:, :1, :, :], target[:, :1, :, :]), n = batch_size)

        else:
            test_acc.update(accuracy(logits_binary, target), n = batch_size)
            test_iou.update(iou(logits_binary, target), n = batch_size)
            test_dice.update(dice(logits_binary, target), n = batch_size)

    return test_acc.getAverage(), test_iou.getAverage(), test_dice.getAverage()

In [24]:
# Training (Inner-Epoch)
def train(train_queue, valid_queue, model, a_optimizer, criterion, w_optimizer, lr, scheduler, epoch, beta_weight):

    objs_ = AvgrageMeter()
    acc_ = AvgrageMeter()
    iou_ = AvgrageMeter()
    dice_ = AvgrageMeter()

    weights = 0 + 50 * epoch / 100

    for step, (inp, tar) in enumerate(train_queue):

        model.train()
        n = inp.size(0)

        inp = inp.cuda()
        tar = tar.cuda(non_blocking=True)

        inp_search, tar_search = next(iter(valid_queue))
        inp_search = inp_search.cuda()
        tar_search = tar_search.cuda(non_blocking=True)

        a_optimizer.zero_grad()
        logits = model(inp_search)
        loss = criterion(logits, tar_search) + weights * model.mlc_loss()
        loss.mean().backward()
        a_optimizer.step()

        w_optimizer.zero_grad()
        logits = model(inp)
        loss = criterion(logits, tar)
        loss.mean().backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        w_optimizer.step()

        objs_.update(loss.item())
        acc_.update(accuracy(logits, tar), n = batch_size)
        iou_.update(iou(logits, tar), n = batch_size)
        dice_.update(dice(logits, tar), n = batch_size)

        if step % rep_freq == 0:
            print("Training Step:", step,
                  "Loss:", objs_.getAverage(),
                  "Accuracy", acc_.getAverage(),
                  "IoU", iou_.getAverage(),
                  "Dice", dice_.getAverage()
                  )
            row_list.append({
                "Time": datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                "Epoch": epoch,
                "Step": step,
                "LR": lr,
                "Loss": objs_.getAverage(),
                "Accuracy": acc_.getAverage(),
                "IoU": iou_.getAverage(),
                "Dice": dice_.getAverage(),
                "Node_1_A": [],
                "Node_2_A": [],
                "Node_3_A": [],
                "SSB_A": [],
                "CSB_A": []
            })


# 8. Training
**NOTE:** The modified [search space paper](https://openreview.net/forum?id=2IkLprQjby) utilizes SGD for the filter weights and ADAM to optimize alpha values. No peudo-hessian, as included in the original DARTs paper, is used for updating weights. Likely faster, but may be less accurate.

**NOTE:** There are two output maps, one for the target class and one for the background class. Metrics like accuracy, IoU, and Dice are calculated on both of these feature maps. This is done for a more wholistic view of the network preformance. However, during inference, only the target class output will be used. We have two seperate output classes during training to allow the network to learn target and background classes independently.

In [25]:
# Hyperparameters
num_ops = len(PRIMITIVES)
rep_freq = 8
batch_size = 24
epochs = 30
lr = 0.025
lr_min = 0.0003
momentum = 0.9
arch_lr = 0.0003
beta_weight = 0.5
weight_decay = 0.0003
drop_path = 0.1
grad_clip = 5.0
data_len = 2048
seed = 2024

np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [26]:
# Load Data
train_data = HAM10000_Dataset("/content/drive/MyDrive/Granados_Thesis_SP24/HAM10000/image_indexing.csv", num_entries = data_len)

# 70% Train/ 20% Validation/ 10% Test
indices = list(range(data_len))
train_val_split = int(np.floor(0.7 * len(train_data)))
val_test_split = int(np.floor(0.9 * len(train_data)))

train_queue = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:train_val_split]),
    pin_memory = True,
    drop_last = True
    )

val_queue = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[train_val_split:val_test_split]),
    pin_memory = True,
    drop_last = True
    )

test_queue = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[val_test_split:]),
    pin_memory = True,
    drop_last = True
    )

In [27]:
# Logging Dataframe
columns_ = [
    "Time", "Epoch", "Step", "LR",
    "Loss", "Accuracy", "IoU", "Dice",
    "Node_1_A", "Node_2_A", "Node_3_A", "SSB_A", "CSB_A"]
row_list = []

In [None]:
torch.cuda.empty_cache()
if not torch.cuda.is_available():
  print("No GPU Device Availible")

# Model
model = Network().cuda()

# Number of Parameters
w_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
a_params = sum(p.numel() for p in model.arch_parameters())
print("Total Parameters:", w_params + a_params)

# Loss
criterion = nn.BCEWithLogitsLoss().cuda()

# Optimizers
w_optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = momentum, weight_decay = weight_decay)
a_optimizer = torch.optim.RAdam([model.arch_parameters()], lr = arch_lr, betas = (0.5, 0.999))

# Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(w_optimizer, float(epochs), eta_min = lr_min)

# Train (Inter-Epoch)
for epoch in range(epochs):
    lr = scheduler.get_lr()[0]
    print(f"epoch {epoch} lr {lr}")

    train(train_queue, val_queue, model, a_optimizer, criterion, w_optimizer, lr, scheduler, epoch, beta_weight)

    with torch.no_grad():

        alpha_tensor = model.arch_parameters()

        for i, a in enumerate(alpha_tensor):

            a = a.cuda().data

            node_1_alphas = [np.round(r.item(), 4) for r in a[: num_ops]]
            node_2_alphas = [np.round(r.item(), 4) for r in a[num_ops : 3 * num_ops]]
            node_3_alphas = [np.round(r.item(), 4) for r in a[3 * num_ops : 6 * num_ops]]
            ssb_alphas = [np.round(r.item(), 4) for r in a[-8:-4]]
            csb_alphas = [np.round(r.item(), 4) for r in a[-4:]]

            row_list.append({
                "Time": datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                "Epoch": epoch,
                "Step": -1,
                "LR": lr,
                "Loss": -1,
                "Accuracy": -1,
                "IoU": -1,
                "Dice": -1,
                "Node_1_A": node_1_alphas,
                "Node_2_A": node_2_alphas,
                "Node_3_A": node_3_alphas,
                "SSB_A": ssb_alphas,
                "CSB_A": csb_alphas
            })

            print(
                "Node_1_Alphas", node_1_alphas,
                "Node_2_Alphas", node_2_alphas,
                "Node_3_Alphas", node_3_alphas,
                "SSB_Alphas", ssb_alphas,
                "CSB_Alphas", csb_alphas,
                )

        if epoch == epochs - 1:
            test_acc, test_iou, test_dice = infer(test_queue, model, criterion)
            print("Test Accuracy:", test_acc, "Test IoU:", test_iou, "Test Dice:", test_dice)
            row_list.append({
                "Time": datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                "Epoch": epoch,
                "Step": -1,
                "LR": lr,
                "Loss": -1,
                "Accuracy": test_acc,
                "IoU": test_iou,
                "Dice": test_dice,
                "Node_1_A": [],
                "Node_2_A": [],
                "Node_3_A": [],
                "SSB_A": [],
                "CSB_A": []
            })

    scheduler.step()
    ctime = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    torch.save(model.state_dict(), f"/content/drive/MyDrive/Granados_Thesis_SP24/Models/{ctime}_epoch_{epoch}.pt")

In [31]:
# Assemble & Save log
log_df = pd.DataFrame(row_list, columns = columns_)
log_df.to_csv('/content/drive/MyDrive/Granados_Thesis_SP24/log_df.csv')

# 9. Testing
**Note:** During discretization, both SSB and CSB blocks are mean-gated. Meaning we only keep connections that are above the mean alpha values of all the connections feeding into the block.   

### 9.1 Discrete Weight Parsing

In [28]:
log_df = pd.read_csv("/content/drive/MyDrive/Granados_Thesis_SP24/log_df.csv").drop(columns = ["Unnamed: 0"])

In [72]:
top_n_ops = 3
finalnet_node_tuples, finalnet_csb_tuples, finalnet_ssb_tuples = [], [], []
finalnet_df = log_df[-17:-1][["Node_1_A", "Node_2_A", "Node_3_A", "SSB_A", "CSB_A"]].applymap(lambda x: ast.literal_eval(x))

for row in finalnet_df.iterrows():
  ctuples = []
  for i, w in enumerate(row[1][:3]):
    nlar = [w.index(j) for j in heapq.nlargest(top_n_ops, w)]
    for n in nlar:
      inp_node = n // len(PRIMITIVES)
      oper = n % len(PRIMITIVES)
      ctuples.append([inp_node, i + 1, PRIMITIVES[oper]])
  finalnet_node_tuples.append(ctuples)

  ssb_mean, csb_mean = np.mean(row[1][3]), np.mean(row[1][4])
  finalnet_ssb_tuples.append([i for i, w in enumerate(row[1][3]) if w >= ssb_mean])
  finalnet_csb_tuples.append([i for i, w in enumerate(row[1][4]) if w >= csb_mean])

### 9.2 Discrete Cell

In [62]:
class DiscreteCell(nn.Module):

  def __init__(self, d_in, d_out, d_img_in, d_img_out, node_tuples, csb_tuples, ssb_tuples):
    super(DiscreteCell, self).__init__()
    '''
    node_tuples in the form (input node, output node, operation) and detail the discrete
    operation choosen.

    csb_tuples and ssb_tuples are a list of node indicies. They are used to choose
    which output nodes get drawn from.
    '''
    self.node_tuples = sorted(node_tuples, key = lambda x: x[1] if x[0] == 0 else int(str(x[0]) + str(x[1])))
    self.num_ops = len(PRIMITIVES)
    self.csb_tuples = csb_tuples
    self.ssb_tuples = ssb_tuples

    # Set of all nodes that output to another node
    self.out_set = list(set(csb_tuples + ssb_tuples + list(list(zip(*node_tuples))[0])))

    # Dimensions of the input feature map and output feature map
    # Used in initial operations s0 to increase storage information as we compress
    self.d_in = d_in
    self.d_out = d_out

    # Dimensions of the input image size and output image size
    # Used for down/up sampling in initial operations (out of s0)
    self.d_img_in = d_img_in
    self.d_img_out = d_img_out

    # Downsample
    if d_img_in > d_img_out:
      self.c0 = DownSampl(d_img_in, d_img_out)

    # Upsample
    elif self.d_img_out > self.d_img_in:
      self.c0 = UpSampl(d_img_in, d_img_out)

    else:
      self.c0 = lambda x: x

    self.nodes = [None, None, None, None]

    self.ops_dict = nn.ModuleDict()

    for inp, out, opn in self.node_tuples:
        if inp == 0:
            self.ops_dict[str(inp) + str(out) + opn] = OPS[opn](self.d_in, self.d_out // 2, 1, False).cuda()
        else:
            self.ops_dict[str(inp) + str(out) + opn] = OPS[opn](self.d_out // 2, self.d_out // 2, 1, False).cuda()

  def forward(self, s0):

    s0 = self.c0(s0.cuda())
    self.nodes[0] = s0

    for i in range(1, 4):
        inputs = filter(lambda x: x[1] == i, self.node_tuples)
        if i in self.out_set:
            tensor_list = []
            for inp, out, opn in inputs:
                tensor_list.append(self.ops_dict[str(inp) + str(out) + opn](self.nodes[inp]))
            self.nodes[i] = sum(tensor_list)

    # Adding or removing feature maps from s0 to they can be combined w/ the output
    if (self.d_out // 2) > self.d_in:
        seq = np.array([[i for _ in range((self.d_out // 2) // self.d_in)] for i in range(self.d_in)]).flatten()
        s0 = torch.cat([s0[:, i:i+1, :, :] for i in seq], dim = 1)

    elif (self.d_out // 2) < self.d_in:
        op0 = nn.Conv2d(self.d_in, self.d_out // 2, 1, groups = self.d_out // 2).cuda()
        s0 = op0(s0)

    self.nodes[0] = s0

    ssb = sum([self.nodes[i] for i in self.ssb_tuples])
    csb_pre = torch.cat([self.nodes[i] for i in self.csb_tuples], dim = 1)
    csb_out = nn.Conv2d(in_channels = (self.d_out // 2) * len(self.csb_tuples), out_channels = self.d_out // 2, kernel_size = 1, stride = 1, padding = 0, bias = False).cuda()(csb_pre)

    return torch.cat([ssb, csb_out], dim = 1)

### 9.3 Discrete Network

In [64]:
class DiscreteNetwork(nn.Module):

  def __init__(self, node_tuples, csb_tuples, ssb_tuples, n = 1):
    super(DiscreteNetwork, self).__init__()
    self.num_ops = len(PRIMITIVES)

    self.cells = nn.ModuleList()

    # U-Net will be composed of a 5-step encoder, and 4-step decoder
    # Encoders gradually compress image but grow in channels
    # Decoder layers are all 80 feature maps thick, and are composed of a concatenation of 16 feature maps from all encoder nodes
    # We output two feature maps that are compared to the foreground and background grouth truth labels; compressed via 1x1 conv block
    # OPTION: n controls how large the hidden layers are via a multiplicative factor

    # Encoder 1: IMG 128x128 -> 128x128, FEAT 3 -> 8
    self.cells.append(BackBoneCell(3, 8 * n, 128, 128))

    # Encoder 2: IMG 128x128 -> 64x64, FEAT 8 -> 16
    self.cells.append(BackBoneCell(8 * n, 16 * n, 128, 64))

    # Encoder 3: IMG 64x64 -> 32x32, FEAT 16 -> 32
    self.cells.append(BackBoneCell(16 * n, 32 * n, 64, 32))

    # Encoder 4: IMG 32x32 -> 16x16, FEAT 32 -> 64
    self.cells.append(BackBoneCell(32 * n, 64 * n, 32, 16))

    # Encoder 5: IMG 16x16 -> 8x8, FEAT 64 -> 128
    self.cells.append(BackBoneCell(64 * n, 128 * n, 16, 8))

    # ---

    # Skip_1_4: IMG 128x128 -> 16x16, FEAT 8 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(8 * n, 32 * n, 128, 16, node_tuples[0], csb_tuples[0], ssb_tuples[0]))

    # Skip_2_4: IMG 64x64 -> 16x16, FEAT 16 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(16 * n, 32 * n, 64, 16, node_tuples[1], csb_tuples[1], ssb_tuples[1]))

    # Skip_3_4: IMG 32x32 -> 16x16, FEAT 32 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(32 * n, 32 * n, 32, 16, node_tuples[2], csb_tuples[2], ssb_tuples[2]))

    # Skip_4_4: IMG 16x16 -> 16x16, FEAT 64 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(64 * n, 32 * n, 16, 16, node_tuples[3], csb_tuples[3], ssb_tuples[3]))

    # Enc_Dec_5_4: IMG 8x8 -> 16x16, FEAT 128 -> 16
    self.cells.append(BackBoneCell(128 * n, 16 * n, 8, 16))

    # Decoder 4: IMG 16x16, FEAT 144 -> 16
    self.cells.append(BackBoneCell(144 * n, 16 * n, 16, 16))

    # ---

    # Skip_1_3: IMG 128x128 -> 32x32, FEAT 8 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(8 * n, 32 * n, 128, 32, node_tuples[4], csb_tuples[4], ssb_tuples[4]))

    # Skip_2_3: IMG 64x64 -> 32x32, FEAT 16 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(16 * n, 32 * n, 64, 32, node_tuples[5], csb_tuples[5], ssb_tuples[5]))

    # Skip_3_3: IMG 32x32 -> 32x32, FEAT 32 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(32 * n, 32 * n, 32, 32, node_tuples[6], csb_tuples[6], ssb_tuples[6]))

    # Skip_4_3: IMG 16x16 -> 32x32, FEAT 64 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(64 * n, 32 * n, 16, 32, node_tuples[7], csb_tuples[7], ssb_tuples[7]))

    # Dec_4_3: IMG 16x16 -> 32x32, FEAT 16 -> 16
    self.cells.append(BackBoneCell(16 * n, 16 * n, 16, 32))

    # Decoder 3: IMG 32x32, FEAT 144 -> 16
    self.cells.append(BackBoneCell(144 * n, 16 * n, 32, 32))

    # ---

    # Skip_1_2: IMG 128x128 -> 64x64, FEAT 8 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(8 * n, 32 * n, 128, 64, node_tuples[8], csb_tuples[8], ssb_tuples[8]))

    # Skip_2_2: IMG 64x64 -> 64x64, FEAT 16 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(16 * n, 32 * n, 64, 64, node_tuples[9], csb_tuples[9], ssb_tuples[9]))

    # Skip_3_2: IMG 32x32 -> 64x64, FEAT 32 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(32 * n, 32 * n, 32, 64, node_tuples[10], csb_tuples[10], ssb_tuples[10]))

    # Skip_4_2: IMG 16x16 -> 64x64, FEAT 64 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(64 * n, 32 * n, 16, 64, node_tuples[11], csb_tuples[11], ssb_tuples[11]))

    # Dec_3_2: IMG 32x32 -> 64x64, FEAT 16 -> 16
    self.cells.append(BackBoneCell(16 * n, 16 * n, 32, 64))

    # Decoder 2: IMG 64x64, FEAT 144 -> 16
    self.cells.append(BackBoneCell(144 * n, 16 * n, 64, 64))

    # ---

    # Skip_1_1: IMG 128x128 -> 128x128, FEAT 8 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(8 * n, 32 * n, 128, 128, node_tuples[12], csb_tuples[12], ssb_tuples[12]))

    # Skip_2_1: IMG 64x64 -> 128x128, FEAT 16 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(16 * n, 32 * n, 64, 128, node_tuples[13], csb_tuples[13], ssb_tuples[13]))

    # Skip_3_1: IMG 32x32 -> 128x128, FEAT 32 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(32 * n, 32 * n, 32, 128, node_tuples[14], csb_tuples[14], ssb_tuples[14]))

    # Skip_4_1: IMG 16x16 -> 128x128, FEAT 64 -> 32 (16 + 16)
    self.cells.append(DiscreteCell(64 * n, 32 * n, 16, 128, node_tuples[15], csb_tuples[15], ssb_tuples[15]))

    # Enc_Dec_2_1: IMG 64x64 -> 128x128, FEAT 16 -> 16
    self.cells.append(BackBoneCell(16 * n, 16 * n, 64, 128))

    # Dec_Out_1 1: IMG 128x128, FEAT 144 -> 2
    self.cells.append(BackBoneCell(144 * n, 2, 128, 128))

  def forward(self, inp):

    # Encoder Branch
    enc1 = self.cells[0](inp)
    enc2 = self.cells[1](enc1)
    enc3 = self.cells[2](enc2)
    enc4 = self.cells[3](enc3)
    enc5 = self.cells[4](enc4)

    # Decoder 4
    skip_1_4 = self.cells[5](enc1)
    skip_2_4 = self.cells[6](enc2)
    skip_3_4 = self.cells[7](enc3)
    skip_4_4 = self.cells[8](enc4)

    enc_dec_5_4 = self.cells[9](enc5)

    dec_4 = self.cells[10](
        torch.cat((
            skip_1_4,
            skip_2_4,
            skip_3_4,
            skip_4_4,
            enc_dec_5_4
            ), dim = 1)
        )

    # Decoder 3
    skip_1_3 = self.cells[11](enc1)
    skip_2_3 = self.cells[12](enc2)
    skip_3_3 = self.cells[13](enc3)
    skip_4_3 = self.cells[14](enc4)

    dec_4_3 = self.cells[15](dec_4)

    dec_3 = self.cells[16](
        torch.cat((
            skip_1_3,
            skip_2_3,
            skip_3_3,
            skip_4_3,
            dec_4_3
            ), dim = 1)
        )

    # Decoder 2
    skip_1_2 = self.cells[17](enc1)
    skip_2_2 = self.cells[18](enc2)
    skip_3_2 = self.cells[19](enc3)
    skip_4_2 = self.cells[20](enc4)

    dec_3_2 = self.cells[21](dec_3)

    dec_2 = self.cells[22](
        torch.cat((
            skip_1_2,
            skip_2_2,
            skip_3_2,
            skip_4_2,
            dec_3_2
            ), dim = 1)
        )

    # Decoder 1
    skip_1_1 = self.cells[23](enc1)
    skip_2_1= self.cells[24](enc2)
    skip_3_1 = self.cells[25](enc3)
    skip_4_1 = self.cells[26](enc4)

    dec_2_1 = self.cells[27](dec_2)

    dec_1 = self.cells[28](
        torch.cat((
            skip_1_1,
            skip_2_1,
            skip_3_1,
            skip_4_1,
            dec_2_1
        ), dim = 1))

    return dec_1

### 9.4 Discrete Training & Evaluation

In [54]:
# Discrete Logging Dataframe
dcolumns_ = [
    "Time", "Epoch", "Step", "LR",
    "Loss", "Accuracy", "IoU", "Dice"
    ]
drow_list = []

In [55]:
# Load Data
data_len = 2560
train_data = HAM10000_Dataset("/content/drive/MyDrive/Granados_Thesis_SP24/HAM10000/image_indexing.csv", num_entries = data_len)

# 47.5% Train / 47.5% Validation/ 5% Test
indices = list(range(data_len))
train_val_split = int(np.floor(0.475 * len(train_data)))
val_test_split = int(np.floor(0.95 * len(train_data)))

train_queue = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:train_val_split]),
    pin_memory = True,
    drop_last = True
    )

val_queue = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[train_val_split:val_test_split]),
    pin_memory = True,
    drop_last = True
    )

test_queue = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[val_test_split:]),
    pin_memory = True,
    drop_last = True
    )

In [None]:
epochs = 40

torch.cuda.empty_cache()
if not torch.cuda.is_available():
  print("No GPU Device Availible")

# Model
discrete_model = DiscreteNetwork(finalnet_node_tuples, finalnet_csb_tuples, finalnet_ssb_tuples, n = 2).cuda()

# Number of Parameters
params = sum(p.numel() for p in discrete_model.parameters() if p.requires_grad)
print("Total Parameters:", params)

# Loss
criterion = nn.BCEWithLogitsLoss().cuda()

# Optimizers
optimizer = torch.optim.RAdam(discrete_model.parameters(), lr = lr, betas = (0.5, 0.999), weight_decay = 1e-4, decoupled_weight_decay = True)

# Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epochs), eta_min = lr_min)

# Train (Inter-Epoch)
for epoch in range(epochs):
    lr = scheduler.get_lr()[0]
    print(f"epoch {epoch} lr {lr}")

    objs_ = AvgrageMeter()
    acc_ = AvgrageMeter()
    iou_ = AvgrageMeter()
    dice_ = AvgrageMeter()

    for step, ((inp, tar), (inp_val, tar_val)) in enumerate(zip(train_queue, val_queue)):

        discrete_model.train()
        n = inp.size(0)

        inp = inp.cuda()
        tar = tar.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits = discrete_model(inp)
        loss = criterion(logits, tar)
        loss.mean().backward()
        nn.utils.clip_grad_norm_(discrete_model.parameters(), grad_clip)
        optimizer.step()

        objs_.update(loss.item())

        discrete_model.eval()
        with torch.no_grad():
          inp_val = inp_val.cuda()
          tar_val = tar_val.cuda(non_blocking=True)

          logits_ = discrete_model(inp_val)

          acc_.update(accuracy(logits_, tar_val), n = batch_size)
          iou_.update(iou(logits_, tar_val), n = batch_size)
          dice_.update(dice(logits_, tar_val), n = batch_size)

        if step % rep_freq == 0:
            print("Training Step:", step,
                  "Loss:", objs_.getAverage(),
                  "Accuracy", acc_.getAverage(),
                  "IoU", iou_.getAverage(),
                  "Dice", dice_.getAverage()
                  )
            drow_list.append({
                "Time": datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                "Epoch": epoch,
                "Step": step,
                "LR": lr,
                "Loss": objs_.getAverage(),
                "Accuracy": acc_.getAverage(),
                "IoU": iou_.getAverage(),
                "Dice": dice_.getAverage()
            })


    with torch.no_grad():

        if epoch == epochs - 1:

            test_acc, test_iou, test_dice = infer(test_queue, discrete_model, criterion)
            print("Test Accuracy:", test_acc, "Test IoU:", test_iou, "Test Dice:", test_dice)
            drow_list.append({
                "Time": datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                "Epoch": epoch,
                "Step": -1,
                "LR": lr,
                "Loss": -1,
                "Accuracy": test_acc,
                "IoU": test_iou,
                "Dice": test_dice
            })

    scheduler.step()

# Assemble & Save log
discrete_log_df = pd.DataFrame(drow_list, columns = dcolumns_)
discrete_log_df.to_csv('/content/drive/MyDrive/Granados_Thesis_SP24/discrete_log_df.csv')

# 10. Results

DARTs-UNET Search Phase

In [10]:
log_df = pd.read_csv('/content/drive/MyDrive/Granados_Thesis_SP24/log_df.csv').drop(columns = ["Unnamed: 0"])

In [52]:
loss_overtime = log_df[log_df["Loss"] >= 0]
loss_overtime.loc[:, "epoch_step"] = loss_overtime["Epoch"] + (loss_overtime["Step"] / 64)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  loss_overtime.loc[:, "epoch_step"] = loss_overtime["Epoch"] + (loss_overtime["Step"] / 64)


In [None]:
plt.title("Network Loss Overtime")
plt.xlabel("Epoch")
plt.ylabel("BCELogitLoss")
plt.xticks(np.arange(0, 31, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
plt.plot(loss_overtime["epoch_step"], loss_overtime["Loss"])
plt.scatter(loss_overtime["epoch_step"][::40], loss_overtime["Loss"][::40])
for x, y in zip(loss_overtime["epoch_step"][::40], loss_overtime["Loss"][::40]):
    plt.text(x + 0.4, y,
        s = "Loss: " + str(round(y, 3)),
        bbox = dict(facecolor = 'grey', edgecolor = 'none', boxstyle = 'round,pad=0.05', alpha = 0.2),
        rotation = 15,
        horizontalalignment = "left",
        verticalalignment = "bottom",
        rotation_mode = "anchor",
        )
plt.show()

In [None]:
plt.title("Network Accuracy Overtime")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.xticks(np.arange(0, 31, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
plt.plot(loss_overtime["epoch_step"], loss_overtime["Accuracy"])
plt.scatter(loss_overtime["epoch_step"][::40], loss_overtime["Accuracy"][::40])
for x, y in zip(loss_overtime["epoch_step"][::40], loss_overtime["Accuracy"][::40]):
    plt.text(x + 0.4, y - 0.01,
        s = "Acc: " + str(round(y, 3)),
        bbox = dict(facecolor = 'grey', edgecolor = 'none', boxstyle = 'round,pad=0.05', alpha = 0.2),
        rotation = -30,
        horizontalalignment = "left",
        verticalalignment = "bottom",
        rotation_mode = "anchor",
        )
plt.show()

In [None]:
plt.title("Network IoU Overtime")
plt.xlabel("Epoch")
plt.ylabel("IoU")
plt.xticks(np.arange(0, 31, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
plt.plot(loss_overtime["epoch_step"], loss_overtime["IoU"])
plt.scatter(loss_overtime["epoch_step"][::40], loss_overtime["IoU"][::40])
for x, y in zip(loss_overtime["epoch_step"][::40], loss_overtime["IoU"][::40]):
    plt.text(x + 0.4, y - 0.01,
        s = "IoU: " + str(round(y, 3)),
        bbox = dict(facecolor = 'grey', edgecolor = 'none', boxstyle = 'round,pad=0.05', alpha = 0.2),
        rotation = -30,
        horizontalalignment = "left",
        verticalalignment = "bottom",
        rotation_mode = "anchor",
        )
plt.show()

In [None]:
plt.title("Network Dice Overtime")
plt.xlabel("Epoch")
plt.ylabel("Dice")
plt.xticks(np.arange(0, 31, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
plt.plot(loss_overtime["epoch_step"], loss_overtime["Dice"])
plt.scatter(loss_overtime["epoch_step"][::40], loss_overtime["Dice"][::40])
for x, y in zip(loss_overtime["epoch_step"][::40], loss_overtime["Dice"][::40]):
    plt.text(x + 0.4, y - 0.01,
        s = "Dice: " + str(round(y, 3)),
        bbox = dict(facecolor = 'grey', edgecolor = 'none', boxstyle = 'round,pad=0.05', alpha = 0.2),
        rotation = -30,
        horizontalalignment = "left",
        verticalalignment = "bottom",
        rotation_mode = "anchor",
        )
plt.show()

In [74]:
alphas_df = log_df[log_df["SSB_A"].apply(lambda x: x != "[]")]

In [None]:
tmp_lst = list(map(lambda x : ast.literal_eval(x), alphas_df["SSB_A"].tolist()))[::16]
tmp_df = pd.DataFrame(tmp_lst, columns = ["Node 0", "Node 1", "Node 2", "Node 3"])
tmp_df["epoch"] = pd.Series(np.arange(30))

plt.title("Skip_1_4 SSB Alphas Overtime")
plt.xlabel("Epoch")
plt.ylabel("Connection Strength")
plt.xticks(np.arange(0, 30, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
for c in tmp_df.columns[:-1]:
    plt.plot(tmp_df["epoch"], tmp_df[c], label = c)

plt.legend()
plt.show()

In [None]:
tmp_lst = list(map(lambda x : ast.literal_eval(x), alphas_df["SSB_A"].tolist()))[10::16]
tmp_df = pd.DataFrame(tmp_lst, columns = ["Node 0", "Node 1", "Node 2", "Node 3"])
tmp_df["epoch"] = pd.Series(np.arange(30))

plt.title("Skip_3_2 SSB Alphas Overtime")
plt.xlabel("Epoch")
plt.ylabel("Connection Strength")
plt.xticks(np.arange(0, 30, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
for c in tmp_df.columns[:-1]:
    plt.plot(tmp_df["epoch"], tmp_df[c], label = c)

plt.legend()
plt.show()

In [None]:
tmp_lst = list(map(lambda x : ast.literal_eval(x), alphas_df["CSB_A"].tolist()))[::16]
tmp_df = pd.DataFrame(tmp_lst, columns = ["Node 0", "Node 1", "Node 2", "Node 3"])
tmp_df["epoch"] = pd.Series(np.arange(30))

plt.title("Skip_1_4 CSB Alphas Overtime")
plt.xlabel("Epoch")
plt.ylabel("Connection Strength")
plt.xticks(np.arange(0, 30, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
for c in tmp_df.columns[:-1]:
    plt.plot(tmp_df["epoch"], tmp_df[c], label = c)

plt.legend()
plt.show()

In [None]:
tmp_lst = list(map(lambda x : ast.literal_eval(x), alphas_df["CSB_A"].tolist()))[10::16]
tmp_df = pd.DataFrame(tmp_lst, columns = ["Node 0", "Node 1", "Node 2", "Node 3"])
tmp_df["epoch"] = pd.Series(np.arange(30))

plt.title("Skip_3_2 CSB Alphas Overtime")
plt.xlabel("Epoch")
plt.ylabel("Connection Strength")
plt.xticks(np.arange(0, 30, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
for c in tmp_df.columns[:-1]:
    plt.plot(tmp_df["epoch"], tmp_df[c], label = c)

plt.legend()
plt.show()

In [None]:
tmp_lst = list(map(lambda x : ast.literal_eval(x), alphas_df["Node_1_A"].tolist()))[::16]
tmp_df = pd.DataFrame(tmp_lst, columns = PRIMITIVES)
tmp_df["epoch"] = pd.Series(np.arange(30))

plt.title("Skip_1_4 Node_1 Alphas Overtime")
plt.xlabel("Epoch")
plt.ylabel("Connection Strength")
plt.xticks(np.arange(0, 30, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
for c in tmp_df.columns[:-1]:
    plt.plot(tmp_df["epoch"], tmp_df[c], label = c)

plt.legend()
plt.show()

In [None]:
tmp_lst = list(map(lambda x : ast.literal_eval(x), alphas_df["Node_1_A"].tolist()))[10::16]
tmp_df = pd.DataFrame(tmp_lst, columns = PRIMITIVES)
tmp_df["epoch"] = pd.Series(np.arange(30))

plt.title("Skip_3_2 Node_1 Alphas Overtime")
plt.xlabel("Epoch")
plt.ylabel("Connection Strength")
plt.xticks(np.arange(0, 30, 2))
plt.grid(color = 'grey', linestyle = '--', linewidth = 1, alpha = 0.3)
for c in tmp_df.columns[:-1]:
    plt.plot(tmp_df["epoch"], tmp_df[c], label = c)

plt.legend()
plt.show()

Discretization Phase