# <center> IMPORT MODULES </center> 

In [2]:
pip install torch-summary

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5
[0mNote: you may need to restart the kernel to use updated packages.


In [3]:
import torch
import pprint
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F

import wandb
from wandb.keras import WandbCallback

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# <center> DIFFERENTIAL CONVOLUTION </center>

In [4]:
class DiffConv4(nn.Module):

    def __init__(self):
        super(DiffConv4, self).__init__()

    def forward(self, x):
        sizeofin = x.size()
                
        ins = sizeofin[0]
        n = sizeofin[1]
        sx = sizeofin[2]
        sy = sizeofin[3]  

        self.output = torch.zeros(ins,n*4,sx,sy, device = x.device)

        for i in range(0,ins):

            # ORIGINAL INPUT
            oM = 0
            oN = n
            
            self.output[i,oM:oN,0:sx,0:sy]= x[i].clone()
            
            # HORIZONTAL DIFFERENCE
            oM = n
            oN = 2*n

            #area = self.output[i,oM:oN,0:sx-1,0:sy]
            self.output[i,oM:oN,0:sx-1,0:sy] = self.output[i,oM:oN,0:sx-1,0:sy].add(x[i,0:n,0:sx-1,0:sy])
            self.output[i,oM:oN,0:sx-1,0:sy] = self.output[i,oM:oN,0:sx-1,0:sy].add(-x[i,0:n,1:sx,0:sy])
            
            # VERTICAL DIFFERENCE
            oM = 2*n
            oN = 3*n

            #area = ptr[oM:oN,0:sx,0:sy-1]
            self.output[i,oM:oN,0:sx,0:sy-1]=self.output[i,oM:oN,0:sx,0:sy-1].add(x[i,0:n,0:sx,0:sy-1])
            self.output[i,oM:oN,0:sx,0:sy-1]=self.output[i,oM:oN,0:sx,0:sy-1].add(-x[i,0:n,0:sx,1:sy])
            
            # SOUTH EAST
            oM = 3*n
            oN = 4*n

            #area = ptr[oM:oN,0:sx-1,0:sy-1]
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(x[i,0:n,0:sx-1,0:sy-1])
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,1:sx,1:sy])
            
#             # SOUTH WEST
#             oM = 4*n
#             oN = 5*n
           
#             #area = ptr[oM:oN,0:sx-1,0:sy-1]
#             self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(x[i,0:n,1:sx,0:sy-1])
#             self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,0:sx-1,1:sy])
            
            

        oM = n
        oN = 4*n

        
        #Inplace error came from the line below
        #self.output[0:ins,oM:oN,:,:] = self.output[0:ins,oM:oN].abs()
        self.output[0:ins,oM:oN,:,:] = torch.abs(self.output[0:ins,oM:oN].clone().detach())
        
        #print(self.output)

        return self.output


In [5]:
class DiffConv5(nn.Module):

    def __init__(self):
        super(DiffConv5, self).__init__()

    def forward(self, x):
        sizeofin = x.size()
        
        ins = sizeofin[0]
        n = sizeofin[1]
        sx = sizeofin[2]
        sy = sizeofin[3]  

        self.output = torch.zeros(ins,n*5,sx,sy, device = x.device)

        for i in range(0,ins):


            # ORIGINAL INPUT
            oM = 0
            oN = n
            
            self.output[i,oM:oN,0:sx,0:sy]= x[i].clone()
            
            # HORIZONTAL DIFFERENCE
            oM = n
            oN = 2*n

            #area = self.output[i,oM:oN,0:sx-1,0:sy]
            self.output[i,oM:oN,0:sx-1,0:sy] = self.output[i,oM:oN,0:sx-1,0:sy].add(x[i,0:n,0:sx-1,0:sy])
            self.output[i,oM:oN,0:sx-1,0:sy] = self.output[i,oM:oN,0:sx-1,0:sy].add(-x[i,0:n,1:sx,0:sy])
            
            # VERTICAL DIFFERENCE
            oM = 2*n
            oN = 3*n

            #area = ptr[oM:oN,0:sx,0:sy-1]
            self.output[i,oM:oN,0:sx,0:sy-1]=self.output[i,oM:oN,0:sx,0:sy-1].add(x[i,0:n,0:sx,0:sy-1])
            self.output[i,oM:oN,0:sx,0:sy-1]=self.output[i,oM:oN,0:sx,0:sy-1].add(-x[i,0:n,0:sx,1:sy])
            
            # SOUTH EAST
            oM = 3*n
            oN = 4*n

            #area = ptr[oM:oN,0:sx-1,0:sy-1]
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(x[i,0:n,0:sx-1,0:sy-1])
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,1:sx,1:sy])
            
            # SOUTH WEST
            oM = 4*n
            oN = 5*n
           
            #area = ptr[oM:oN,0:sx-1,0:sy-1]
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(x[i,0:n,1:sx,0:sy-1])
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,0:sx-1,1:sy])
            
            

        oM = n
        oN = 5*n

        
        #Inplace error came from the line below
        #self.output[0:ins,oM:oN,:,:] = self.output[0:ins,oM:oN].abs()
        self.output[0:ins,oM:oN,:,:] = torch.abs(self.output[0:ins,oM:oN].clone().detach())
        
        #print(self.output)

        return self.output


In [6]:
class DiffConv6(nn.Module):

    def __init__(self):
        super(DiffConv6, self).__init__()

    def forward(self, x):
        sizeofin = x.size()
        
        ins = sizeofin[0]
        n = sizeofin[1]
        sx = sizeofin[2]
        sy = sizeofin[3]  

        self.output = torch.zeros(ins,n*6,sx,sy, device = x.device)
        dmax = x.max()
        dmin = x.min()
        
        for i in range(0,ins):


            # ORIGINAL INPUT
            oM = 0
            oN = n
            
            self.output[i,oM:oN,0:sx,0:sy]= x[i].clone()
            
            # HORIZONTAL DIFFERENCE
            oM = n
            oN = 2*n

            #area = self.output[i,oM:oN,0:sx-1,0:sy]
            self.output[i,oM:oN,0:sx-1,0:sy] = self.output[i,oM:oN,0:sx-1,0:sy].add(x[i,0:n,0:sx-1,0:sy])
            self.output[i,oM:oN,0:sx-1,0:sy] = self.output[i,oM:oN,0:sx-1,0:sy].add(-x[i,0:n,1:sx,0:sy])
            
            # VERTICAL DIFFERENCE
            oM = 2*n
            oN = 3*n

            #area = ptr[oM:oN,0:sx,0:sy-1]
            self.output[i,oM:oN,0:sx,0:sy-1]=self.output[i,oM:oN,0:sx,0:sy-1].add(x[i,0:n,0:sx,0:sy-1])
            self.output[i,oM:oN,0:sx,0:sy-1]=self.output[i,oM:oN,0:sx,0:sy-1].add(-x[i,0:n,0:sx,1:sy])
            
            # SOUTH EAST
            oM = 3*n
            oN = 4*n

            #area = ptr[oM:oN,0:sx-1,0:sy-1]
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(x[i,0:n,0:sx-1,0:sy-1])
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,1:sx,1:sy])
            
            # SOUTH WEST
            oM = 4*n
            oN = 5*n
           
            #area = ptr[oM:oN,0:sx-1,0:sy-1]
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(x[i,0:n,1:sx,0:sy-1])
            self.output[i,oM:oN,0:sx-1,0:sy-1]=self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,0:sx-1,1:sy])
            
            # RIGHT - SOUTH EAST - BOTTOM
            oM = 5*n
            oN = 6*n
           
            #area = ptr[oM:oN,0:sx-1,0:sy-1]
            self.output[i,oM:oN,0:sx-1,0:sy-1] = self.output[i,oM:oN,0:sx-1,0:sy-1].add(x[i,0:n,0:sx-1,0:sy-1])
            self.output[i,oM:oN,0:sx-1,0:sy-1] = self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,1:sx,0:sy-1])
            self.output[i,oM:oN,0:sx-1,0:sy-1] = self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,0:sx-1,1:sy])
            self.output[i,oM:oN,0:sx-1,0:sy-1] = self.output[i,oM:oN,0:sx-1,0:sy-1].add(-x[i,0:n,1:sx,1:sy])
            
            #SCALE [-765 to 255] to [0 to 255]
            
            amin = self.output[i,oM:oN,0:sx-1,0:sy-1].clone().detach().min()
            amax = self.output[i,oM:oN,0:sx-1,0:sy-1].clone().detach().max()
            # Scale: -255 to 85
            self.output[i,oM:oN,0:sx-1,0:sy-1] = ( (self.output[i,oM:oN,0:sx-1,0:sy-1].clone().detach()) / 3 ).int()
            

        oM = n
        oN = 6*n

           
        #self.signInputs = self.output.sign()
        #self.signInputs[0:ins,0:n,:,:] = torch.ones(ins,n,sx,sy)
        
        #Inplace error came from the line below
        #self.output[0:ins,oM:oN,:,:] = self.output[0:ins,oM:oN].abs()
        self.output[0:ins,oM:oN,:,:] = torch.abs(self.output[0:ins,oM:oN].clone().detach())
        
        #self.output[0:ins,oM:oN,:,:] = torch.sqrt(self.output[0:ins,oM:oN] * self.output[0:ins,oM:oN].clone().detach())

        #print(self.output)

        return self.output

# <center> WandB Init </center>

In [7]:
# Set up Weights and Biases
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# <center> DATASET </center>

In [8]:
# Define transforms for data preprocessing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

# Load CIFAR10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# <center> Network Architectures </center>

In [9]:
multiplier = 8

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 8*multiplier, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(8*multiplier)
        self.pool = nn.MaxPool2d(2, 2) 
        
        self.conv2 = nn.Conv2d(8*multiplier, 16*multiplier, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(16*multiplier)
        
        self.conv3 = nn.Conv2d(16*multiplier, 32*multiplier, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(32*multiplier)
        
        self.fc1 = nn.Linear(32*multiplier * 4 * 4, 128*multiplier)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(128*multiplier, 64*multiplier)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(64*multiplier, 10)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x) # 16x16
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x) # 8x8
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x) # 4x4

        x = x.view(-1, 32*multiplier * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x

In [10]:
m = CNN()
summary(m, (3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 32, 32]          1,792
├─BatchNorm2d: 1-2                       [-1, 64, 32, 32]          128
├─MaxPool2d: 1-3                         [-1, 64, 16, 16]          --
├─Conv2d: 1-4                            [-1, 128, 16, 16]         73,856
├─BatchNorm2d: 1-5                       [-1, 128, 16, 16]         256
├─MaxPool2d: 1-6                         [-1, 128, 8, 8]           --
├─Conv2d: 1-7                            [-1, 256, 8, 8]           295,168
├─BatchNorm2d: 1-8                       [-1, 256, 8, 8]           512
├─MaxPool2d: 1-9                         [-1, 256, 4, 4]           --
├─Linear: 1-10                           [-1, 1024]                4,195,328
├─Dropout: 1-11                          [-1, 1024]                --
├─Linear: 1-12                           [-1, 512]                 524,800
├─Dropout: 1-13                          [-1, 512]        

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 32, 32]          1,792
├─BatchNorm2d: 1-2                       [-1, 64, 32, 32]          128
├─MaxPool2d: 1-3                         [-1, 64, 16, 16]          --
├─Conv2d: 1-4                            [-1, 128, 16, 16]         73,856
├─BatchNorm2d: 1-5                       [-1, 128, 16, 16]         256
├─MaxPool2d: 1-6                         [-1, 128, 8, 8]           --
├─Conv2d: 1-7                            [-1, 256, 8, 8]           295,168
├─BatchNorm2d: 1-8                       [-1, 256, 8, 8]           512
├─MaxPool2d: 1-9                         [-1, 256, 4, 4]           --
├─Linear: 1-10                           [-1, 1024]                4,195,328
├─Dropout: 1-11                          [-1, 1024]                --
├─Linear: 1-12                           [-1, 512]                 524,800
├─Dropout: 1-13                          [-1, 512]        

In [11]:
multiplier = 4

# CNN with DiffConv4
class NetDiffConv4(nn.Module):
    def __init__(self):
        super(NetDiffConv4, self).__init__()
        
        self.diff = DiffConv4()
        
        self.conv1 = nn.Conv2d(12, 16*multiplier, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16*multiplier)
        self.pool = nn.MaxPool2d(2, 2) 
        
        self.conv2 = nn.Conv2d(16*multiplier, 32*multiplier, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32*multiplier)
        
        self.conv3 = nn.Conv2d(32*multiplier, 64*multiplier, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64*multiplier)
        
        self.fc1 = nn.Linear(64*multiplier * 4 * 4, 256*multiplier)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(256*multiplier, 128*multiplier)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(128*multiplier, 10)
        
    def forward(self, x):
        x = self.diff(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x) # 16x16
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x) # 8x8
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x) # 4x4

        x = x.view(-1, 64*multiplier * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


# CNN with DiffConv5
class NetDiffConv5(nn.Module):
    def __init__(self):
        super(NetDiffConv5, self).__init__()
        self.diff = DiffConv5()
        
        self.conv1 = nn.Conv2d(15, 16*multiplier, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16*multiplier)
        self.pool = nn.MaxPool2d(2, 2) 
        
        self.conv2 = nn.Conv2d(16*multiplier, 32*multiplier, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32*multiplier)
        
        self.conv3 = nn.Conv2d(32*multiplier, 64*multiplier, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64*multiplier)
        
        self.fc1 = nn.Linear(64*multiplier * 4 * 4, 256*multiplier)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(256*multiplier, 128*multiplier)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(128*multiplier, 10)
        
    def forward(self, x):
        x = self.diff(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x) # 16x16
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x) # 8x8
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x) # 4x4

        x = x.view(-1, 64*multiplier * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x



# CNN with DiffConv6
class NetDiffConv6(nn.Module):
    def __init__(self):
        super(NetDiffConv6, self).__init__()
        self.diff = DiffConv6()

        self.conv1 = nn.Conv2d(18, 16*multiplier, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16*multiplier)
        self.pool = nn.MaxPool2d(2, 2) 
        
        self.conv2 = nn.Conv2d(16*multiplier, 32*multiplier, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32*multiplier)
        
        self.conv3 = nn.Conv2d(32*multiplier, 64*multiplier, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64*multiplier)
        
        self.fc1 = nn.Linear(64*multiplier * 4 * 4, 256*multiplier)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(256*multiplier, 128*multiplier)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(128*multiplier, 10)
        
    def forward(self, x):
        x = self.diff(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x) # 16x16
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x) # 8x8
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x) # 4x4

        x = x.view(-1, 64*multiplier * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x

# <center> WandB Config </center>

In [12]:
sweep_config = {
    'method': 'grid',
    
    'metric' : {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    
    'parameters' : {
        'model' : {
            'values': ['CNN', 'NetDiffConv4', 'NetDiffConv5', 'NetDiffConv6']
                  }
    }
    
    }

pprint.pprint(sweep_config)

{'method': 'grid',
 'metric': {'goal': 'minimize', 'name': 'val_loss'},
 'parameters': {'model': {'values': ['CNN',
                                     'NetDiffConv4',
                                     'NetDiffConv5',
                                     'NetDiffConv6']}}}


In [13]:
sweep_id = wandb.sweep(sweep_config, project="Diffconv_Cifar10_3")

Create sweep with ID: vkwu2aor
Sweep URL: https://wandb.ai/shiwayz/Diffconv_Cifar10_3/sweeps/vkwu2aor


# <center> TRAINING ++ </center>

In [14]:
# Train loop
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in tqdm(dataloader, desc='Training', leave=False):
        inputs, targets = inputs.to(device), targets.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # compute statistics
        running_loss += loss.item() * inputs.size(0)
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

# Test loop
def test(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc='Testing', leave=False):
            inputs, targets = inputs.to(device), targets.to(device)

            # forward
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # compute statistics
            running_loss += loss.item() * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


In [15]:
batch_size = 16
learning_rate = 0.0001
num_epochs = 30
# optimizer = 'Adam'
lr_step_size = 10
lr_gamma = 0.1
betas = (0.85, 0.999)
amsgrad = True

In [16]:
def main(config = None):
    
    # Define data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

    # Train and evaluate models
    with wandb.init(config = config):
        config = wandb.config
        model_name = config.model

        if config.model == 'CNN':
            model = CNN().to(device)
        elif config.model == 'NetDiffConv4':
            model = NetDiffConv4().to(device)
        elif config.model == 'NetDiffConv5':
            model = NetDiffConv5().to(device)
        elif config.model == 'NetDiffConv6':
            model = NetDiffConv6().to(device)

        print(f'Training {model_name} model...')

        optimizer = optim.Adam(model.parameters(),
                               lr=learning_rate,
                               betas = betas,
                               amsgrad = amsgrad,
                              )

    # LEARNING RATE SCHEDULERS
    #     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=num_epochs//3, T_mult=2, eta_min=0)

        criterion = nn.CrossEntropyLoss()
        best_val_acc = 0.0
        for epoch in range(1, num_epochs+1):
            scheduler.step()
            train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc = test(model, test_loader, criterion, device)
            print(f'Epoch {epoch}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            # log results to weights and biases
            wandb.log({'train_loss': train_loss,
                       'train_acc': train_acc,
                       'val_loss': val_loss,
                       'val_acc': val_acc})
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), f'{model_name}.pt')
        print(f'{model_name} model finished training with best validation accuracy of {best_val_acc:.4f}')


In [17]:
wandb.agent(sweep_id, main, count=4)

[34m[1mwandb[0m: Agent Starting Run: 0q47vtgl with config:
[34m[1mwandb[0m: 	model: CNN
[34m[1mwandb[0m: Currently logged in as: [33mshiwayz[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training CNN model...


                                                             

Epoch 1/30, Train Loss: 1.5746, Train Acc: 0.4182, Val Loss: 1.2081, Val Acc: 0.5539


                                                             

Epoch 2/30, Train Loss: 1.2789, Train Acc: 0.5344, Val Loss: 1.0222, Val Acc: 0.6339


                                                             

Epoch 3/30, Train Loss: 1.1559, Train Acc: 0.5830, Val Loss: 0.9826, Val Acc: 0.6407


                                                             

Epoch 4/30, Train Loss: 1.0761, Train Acc: 0.6164, Val Loss: 0.8629, Val Acc: 0.6907


                                                             

Epoch 5/30, Train Loss: 1.0127, Train Acc: 0.6401, Val Loss: 0.8282, Val Acc: 0.7055


                                                             

Epoch 6/30, Train Loss: 0.9682, Train Acc: 0.6544, Val Loss: 0.8056, Val Acc: 0.7146


                                                             

Epoch 7/30, Train Loss: 0.9342, Train Acc: 0.6689, Val Loss: 0.7809, Val Acc: 0.7230


                                                             

Epoch 8/30, Train Loss: 0.9113, Train Acc: 0.6759, Val Loss: 0.7692, Val Acc: 0.7281


                                                             

Epoch 9/30, Train Loss: 0.9016, Train Acc: 0.6796, Val Loss: 0.7635, Val Acc: 0.7287


                                                             

Epoch 10/30, Train Loss: 0.9794, Train Acc: 0.6525, Val Loss: 0.8103, Val Acc: 0.7139


                                                             

Epoch 11/30, Train Loss: 0.9376, Train Acc: 0.6675, Val Loss: 0.7730, Val Acc: 0.7250


                                                             

Epoch 12/30, Train Loss: 0.9070, Train Acc: 0.6800, Val Loss: 0.7386, Val Acc: 0.7387


                                                             

Epoch 13/30, Train Loss: 0.8780, Train Acc: 0.6899, Val Loss: 0.7196, Val Acc: 0.7437


                                                             

Epoch 14/30, Train Loss: 0.8568, Train Acc: 0.6976, Val Loss: 0.7276, Val Acc: 0.7422


                                                             

Epoch 15/30, Train Loss: 0.8287, Train Acc: 0.7081, Val Loss: 0.7008, Val Acc: 0.7516


                                                             

Epoch 16/30, Train Loss: 0.8089, Train Acc: 0.7162, Val Loss: 0.6539, Val Acc: 0.7699


                                                             

Epoch 17/30, Train Loss: 0.7859, Train Acc: 0.7231, Val Loss: 0.6645, Val Acc: 0.7710


                                                             

Epoch 18/30, Train Loss: 0.7676, Train Acc: 0.7294, Val Loss: 0.6354, Val Acc: 0.7797


                                                             

Epoch 19/30, Train Loss: 0.7471, Train Acc: 0.7394, Val Loss: 0.6208, Val Acc: 0.7837


                                                             

Epoch 20/30, Train Loss: 0.7308, Train Acc: 0.7416, Val Loss: 0.6156, Val Acc: 0.7861


                                                             

Epoch 21/30, Train Loss: 0.7129, Train Acc: 0.7488, Val Loss: 0.6056, Val Acc: 0.7914


                                                             

Epoch 22/30, Train Loss: 0.7042, Train Acc: 0.7527, Val Loss: 0.5953, Val Acc: 0.7928


                                                             

Epoch 23/30, Train Loss: 0.6912, Train Acc: 0.7575, Val Loss: 0.5844, Val Acc: 0.7980


                                                             

Epoch 24/30, Train Loss: 0.6762, Train Acc: 0.7635, Val Loss: 0.5885, Val Acc: 0.7920


                                                             

Epoch 25/30, Train Loss: 0.6693, Train Acc: 0.7648, Val Loss: 0.5753, Val Acc: 0.7994


                                                             

Epoch 26/30, Train Loss: 0.6579, Train Acc: 0.7689, Val Loss: 0.5744, Val Acc: 0.7970


                                                             

Epoch 27/30, Train Loss: 0.6578, Train Acc: 0.7696, Val Loss: 0.5733, Val Acc: 0.7996


                                                             

Epoch 28/30, Train Loss: 0.6559, Train Acc: 0.7715, Val Loss: 0.5702, Val Acc: 0.7990


                                                             

Epoch 29/30, Train Loss: 0.6476, Train Acc: 0.7734, Val Loss: 0.5664, Val Acc: 0.8033


                                                             

Epoch 30/30, Train Loss: 0.7427, Train Acc: 0.7384, Val Loss: 0.6361, Val Acc: 0.7757
CNN model finished training with best validation accuracy of 0.8033


VBox(children=(Label(value='0.073 MB of 0.073 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_acc,▁▃▄▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇█████████▇
train_loss,█▆▅▄▄▃▃▃▃▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂
val_acc,▁▃▃▅▅▆▆▆▆▅▆▆▆▆▇▇▇▇▇██████████▇
val_loss,█▆▆▄▄▄▃▃▃▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂

0,1
train_acc,0.73842
train_loss,0.74272
val_acc,0.7757
val_loss,0.63606


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: hokcc84w with config:
[34m[1mwandb[0m: 	model: NetDiffConv4


Training NetDiffConv4 model...


                                                             

Epoch 1/30, Train Loss: 1.4453, Train Acc: 0.4687, Val Loss: 1.0747, Val Acc: 0.6052


                                                             

Epoch 2/30, Train Loss: 1.0920, Train Acc: 0.6084, Val Loss: 0.8753, Val Acc: 0.6902


                                                             

Epoch 3/30, Train Loss: 0.9578, Train Acc: 0.6592, Val Loss: 0.7992, Val Acc: 0.7139


                                                             

Epoch 4/30, Train Loss: 0.8764, Train Acc: 0.6888, Val Loss: 0.7284, Val Acc: 0.7433


                                                             

Epoch 5/30, Train Loss: 0.8097, Train Acc: 0.7132, Val Loss: 0.6608, Val Acc: 0.7682


                                                             

Epoch 6/30, Train Loss: 0.7589, Train Acc: 0.7316, Val Loss: 0.6305, Val Acc: 0.7801


                                                             

Epoch 7/30, Train Loss: 0.7176, Train Acc: 0.7508, Val Loss: 0.6057, Val Acc: 0.7865


                                                             

Epoch 8/30, Train Loss: 0.6911, Train Acc: 0.7580, Val Loss: 0.5996, Val Acc: 0.7879


                                                             

Epoch 9/30, Train Loss: 0.6746, Train Acc: 0.7635, Val Loss: 0.5897, Val Acc: 0.7921


                                                             

Epoch 10/30, Train Loss: 0.7923, Train Acc: 0.7217, Val Loss: 0.6540, Val Acc: 0.7709


                                                             

Epoch 11/30, Train Loss: 0.7550, Train Acc: 0.7329, Val Loss: 0.6362, Val Acc: 0.7765


                                                             

Epoch 12/30, Train Loss: 0.7217, Train Acc: 0.7453, Val Loss: 0.6108, Val Acc: 0.7838


                                                             

Epoch 13/30, Train Loss: 0.6917, Train Acc: 0.7558, Val Loss: 0.5685, Val Acc: 0.8039


                                                             

Epoch 14/30, Train Loss: 0.6703, Train Acc: 0.7652, Val Loss: 0.5633, Val Acc: 0.8070


                                                             

Epoch 15/30, Train Loss: 0.6441, Train Acc: 0.7739, Val Loss: 0.5472, Val Acc: 0.8127


                                                             

Epoch 16/30, Train Loss: 0.6226, Train Acc: 0.7820, Val Loss: 0.5341, Val Acc: 0.8150


                                                             

Epoch 17/30, Train Loss: 0.6043, Train Acc: 0.7867, Val Loss: 0.5278, Val Acc: 0.8170


                                                             

Epoch 18/30, Train Loss: 0.5829, Train Acc: 0.7954, Val Loss: 0.5337, Val Acc: 0.8208


                                                             

Epoch 19/30, Train Loss: 0.5652, Train Acc: 0.8010, Val Loss: 0.4911, Val Acc: 0.8302


                                                             

Epoch 20/30, Train Loss: 0.5495, Train Acc: 0.8054, Val Loss: 0.4786, Val Acc: 0.8331


                                                             

Epoch 21/30, Train Loss: 0.5279, Train Acc: 0.8139, Val Loss: 0.4770, Val Acc: 0.8364


                                                             

Epoch 22/30, Train Loss: 0.5112, Train Acc: 0.8212, Val Loss: 0.4618, Val Acc: 0.8403


                                                             

Epoch 23/30, Train Loss: 0.4977, Train Acc: 0.8259, Val Loss: 0.4621, Val Acc: 0.8432


                                                             

Epoch 24/30, Train Loss: 0.4819, Train Acc: 0.8322, Val Loss: 0.4454, Val Acc: 0.8491


                                                             

Epoch 25/30, Train Loss: 0.4678, Train Acc: 0.8352, Val Loss: 0.4422, Val Acc: 0.8496


                                                             

Epoch 26/30, Train Loss: 0.4574, Train Acc: 0.8409, Val Loss: 0.4343, Val Acc: 0.8515


                                                             

Epoch 27/30, Train Loss: 0.4527, Train Acc: 0.8405, Val Loss: 0.4349, Val Acc: 0.8525


                                                             

Epoch 28/30, Train Loss: 0.4488, Train Acc: 0.8421, Val Loss: 0.4354, Val Acc: 0.8520


                                                             

Epoch 29/30, Train Loss: 0.4457, Train Acc: 0.8426, Val Loss: 0.4322, Val Acc: 0.8547


                                                             

Epoch 30/30, Train Loss: 0.5670, Train Acc: 0.8001, Val Loss: 0.5065, Val Acc: 0.8290
NetDiffConv4 model finished training with best validation accuracy of 0.8547


VBox(children=(Label(value='0.086 MB of 0.086 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_acc,▁▄▅▅▆▆▆▆▇▆▆▆▆▇▇▇▇▇▇▇▇████████▇
train_loss,█▆▅▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂
val_acc,▁▃▄▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇████████▇
val_loss,█▆▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂

0,1
train_acc,0.80006
train_loss,0.56703
val_acc,0.829
val_loss,0.50647


[34m[1mwandb[0m: Agent Starting Run: 0zcqlilp with config:
[34m[1mwandb[0m: 	model: NetDiffConv5


Training NetDiffConv5 model...


                                                             

Epoch 1/30, Train Loss: 1.4464, Train Acc: 0.4696, Val Loss: 1.0077, Val Acc: 0.6384


                                                             

Epoch 2/30, Train Loss: 1.0763, Train Acc: 0.6139, Val Loss: 0.8288, Val Acc: 0.7030


                                                             

Epoch 3/30, Train Loss: 0.9430, Train Acc: 0.6653, Val Loss: 0.7665, Val Acc: 0.7293


                                                             

Epoch 4/30, Train Loss: 0.8634, Train Acc: 0.6933, Val Loss: 0.6926, Val Acc: 0.7521


                                                             

Epoch 5/30, Train Loss: 0.7895, Train Acc: 0.7210, Val Loss: 0.6744, Val Acc: 0.7659


                                                             

Epoch 6/30, Train Loss: 0.7393, Train Acc: 0.7403, Val Loss: 0.6164, Val Acc: 0.7852


                                                             

Epoch 7/30, Train Loss: 0.7009, Train Acc: 0.7524, Val Loss: 0.6058, Val Acc: 0.7872


                                                             

Epoch 8/30, Train Loss: 0.6762, Train Acc: 0.7606, Val Loss: 0.5846, Val Acc: 0.7935


                                                             

Epoch 9/30, Train Loss: 0.6590, Train Acc: 0.7656, Val Loss: 0.5751, Val Acc: 0.7980


                                                             

Epoch 10/30, Train Loss: 0.7724, Train Acc: 0.7295, Val Loss: 0.6520, Val Acc: 0.7718


                                                             

Epoch 11/30, Train Loss: 0.7341, Train Acc: 0.7409, Val Loss: 0.6067, Val Acc: 0.7892


                                                             

Epoch 12/30, Train Loss: 0.7134, Train Acc: 0.7509, Val Loss: 0.5747, Val Acc: 0.8005


                                                             

Epoch 13/30, Train Loss: 0.6799, Train Acc: 0.7607, Val Loss: 0.5900, Val Acc: 0.7931


                                                             

Epoch 14/30, Train Loss: 0.6518, Train Acc: 0.7707, Val Loss: 0.5501, Val Acc: 0.8111


                                                             

Epoch 15/30, Train Loss: 0.6319, Train Acc: 0.7789, Val Loss: 0.5335, Val Acc: 0.8179


                                                             

Epoch 16/30, Train Loss: 0.6136, Train Acc: 0.7830, Val Loss: 0.5245, Val Acc: 0.8199


                                                             

Epoch 17/30, Train Loss: 0.5858, Train Acc: 0.7941, Val Loss: 0.5173, Val Acc: 0.8249


                                                             

Epoch 18/30, Train Loss: 0.5738, Train Acc: 0.7986, Val Loss: 0.5057, Val Acc: 0.8267


                                                             

Epoch 19/30, Train Loss: 0.5594, Train Acc: 0.8037, Val Loss: 0.4748, Val Acc: 0.8360


                                                             

Epoch 20/30, Train Loss: 0.5305, Train Acc: 0.8135, Val Loss: 0.4910, Val Acc: 0.8361


                                                             

Epoch 21/30, Train Loss: 0.5147, Train Acc: 0.8196, Val Loss: 0.4595, Val Acc: 0.8441


                                                             

Epoch 22/30, Train Loss: 0.4986, Train Acc: 0.8257, Val Loss: 0.4498, Val Acc: 0.8449


                                                             

Epoch 23/30, Train Loss: 0.4803, Train Acc: 0.8315, Val Loss: 0.4391, Val Acc: 0.8512


                                                             

Epoch 24/30, Train Loss: 0.4709, Train Acc: 0.8339, Val Loss: 0.4368, Val Acc: 0.8505


                                                             

Epoch 25/30, Train Loss: 0.4588, Train Acc: 0.8378, Val Loss: 0.4306, Val Acc: 0.8516


                                                             

Epoch 26/30, Train Loss: 0.4493, Train Acc: 0.8425, Val Loss: 0.4312, Val Acc: 0.8526


                                                             

Epoch 27/30, Train Loss: 0.4415, Train Acc: 0.8446, Val Loss: 0.4269, Val Acc: 0.8552


                                                             

Epoch 28/30, Train Loss: 0.4377, Train Acc: 0.8458, Val Loss: 0.4231, Val Acc: 0.8549


                                                             

Epoch 29/30, Train Loss: 0.4349, Train Acc: 0.8477, Val Loss: 0.4241, Val Acc: 0.8558


                                                             

Epoch 30/30, Train Loss: 0.5550, Train Acc: 0.8047, Val Loss: 0.4853, Val Acc: 0.8321
NetDiffConv5 model finished training with best validation accuracy of 0.8558


VBox(children=(Label(value='0.099 MB of 0.099 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_acc,▁▄▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████▇
train_loss,█▅▅▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂
val_acc,▁▃▄▅▅▆▆▆▆▅▆▆▆▇▇▇▇▇▇▇█████████▇
val_loss,█▆▅▄▄▃▃▃▃▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂

0,1
train_acc,0.80472
train_loss,0.55503
val_acc,0.8321
val_loss,0.48528


[34m[1mwandb[0m: Agent Starting Run: 5q1ph9jq with config:
[34m[1mwandb[0m: 	model: NetDiffConv6


Training NetDiffConv6 model...


                                                             

Epoch 1/30, Train Loss: 1.4402, Train Acc: 0.4710, Val Loss: 0.9967, Val Acc: 0.6418


                                                             

Epoch 2/30, Train Loss: 1.0764, Train Acc: 0.6141, Val Loss: 0.8690, Val Acc: 0.6889


                                                             

Epoch 3/30, Train Loss: 0.9361, Train Acc: 0.6655, Val Loss: 0.7674, Val Acc: 0.7276


                                                             

Epoch 4/30, Train Loss: 0.8521, Train Acc: 0.6989, Val Loss: 0.7059, Val Acc: 0.7532


                                                             

Epoch 5/30, Train Loss: 0.7914, Train Acc: 0.7188, Val Loss: 0.6394, Val Acc: 0.7738


                                                             

Epoch 6/30, Train Loss: 0.7373, Train Acc: 0.7377, Val Loss: 0.6115, Val Acc: 0.7848


                                                             

Epoch 7/30, Train Loss: 0.6941, Train Acc: 0.7537, Val Loss: 0.6010, Val Acc: 0.7900


                                                             

Epoch 8/30, Train Loss: 0.6689, Train Acc: 0.7633, Val Loss: 0.5764, Val Acc: 0.7968


                                                             

Epoch 9/30, Train Loss: 0.6506, Train Acc: 0.7703, Val Loss: 0.5697, Val Acc: 0.7991


                                                             

Epoch 10/30, Train Loss: 0.7672, Train Acc: 0.7268, Val Loss: 0.6672, Val Acc: 0.7656


                                                             

Epoch 11/30, Train Loss: 0.7344, Train Acc: 0.7407, Val Loss: 0.6017, Val Acc: 0.7913


                                                             

Epoch 12/30, Train Loss: 0.7013, Train Acc: 0.7525, Val Loss: 0.6030, Val Acc: 0.7849


                                                             

Epoch 13/30, Train Loss: 0.6760, Train Acc: 0.7626, Val Loss: 0.6316, Val Acc: 0.7855


                                                             

Epoch 14/30, Train Loss: 0.6505, Train Acc: 0.7712, Val Loss: 0.5895, Val Acc: 0.7997


                                                             

Epoch 15/30, Train Loss: 0.6240, Train Acc: 0.7811, Val Loss: 0.5362, Val Acc: 0.8153


                                                             

Epoch 16/30, Train Loss: 0.6045, Train Acc: 0.7876, Val Loss: 0.5325, Val Acc: 0.8176


                                                             

Epoch 17/30, Train Loss: 0.5830, Train Acc: 0.7965, Val Loss: 0.5122, Val Acc: 0.8250


                                                             

Epoch 18/30, Train Loss: 0.5603, Train Acc: 0.8028, Val Loss: 0.4997, Val Acc: 0.8322


                                                             

Epoch 19/30, Train Loss: 0.5376, Train Acc: 0.8112, Val Loss: 0.4755, Val Acc: 0.8364


                                                             

Epoch 20/30, Train Loss: 0.5247, Train Acc: 0.8148, Val Loss: 0.4885, Val Acc: 0.8342


                                                             

Epoch 21/30, Train Loss: 0.5051, Train Acc: 0.8229, Val Loss: 0.4795, Val Acc: 0.8344


                                                             

Epoch 22/30, Train Loss: 0.4940, Train Acc: 0.8276, Val Loss: 0.4701, Val Acc: 0.8421


                                                             

Epoch 23/30, Train Loss: 0.4749, Train Acc: 0.8324, Val Loss: 0.4479, Val Acc: 0.8481


                                                             

Epoch 24/30, Train Loss: 0.4597, Train Acc: 0.8389, Val Loss: 0.4477, Val Acc: 0.8484


                                                             

Epoch 25/30, Train Loss: 0.4492, Train Acc: 0.8452, Val Loss: 0.4392, Val Acc: 0.8506


                                                             

Epoch 26/30, Train Loss: 0.4369, Train Acc: 0.8454, Val Loss: 0.4347, Val Acc: 0.8530


                                                             

Epoch 27/30, Train Loss: 0.4335, Train Acc: 0.8475, Val Loss: 0.4290, Val Acc: 0.8526


                                                             

Epoch 28/30, Train Loss: 0.4294, Train Acc: 0.8497, Val Loss: 0.4283, Val Acc: 0.8558


                                                             

Epoch 29/30, Train Loss: 0.4239, Train Acc: 0.8502, Val Loss: 0.4260, Val Acc: 0.8562


                                                             

Epoch 30/30, Train Loss: 0.5384, Train Acc: 0.8112, Val Loss: 0.5372, Val Acc: 0.8185
NetDiffConv6 model finished training with best validation accuracy of 0.8562


VBox(children=(Label(value='0.111 MB of 0.111 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_acc,▁▄▅▅▆▆▆▆▇▆▆▆▆▇▇▇▇▇▇▇▇████████▇
train_loss,█▅▅▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂
val_acc,▁▃▄▅▅▆▆▆▆▅▆▆▆▆▇▇▇▇▇▇▇████████▇
val_loss,█▆▅▄▄▃▃▃▃▄▃▃▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▂

0,1
train_acc,0.81124
train_loss,0.53841
val_acc,0.8185
val_loss,0.53723


In [18]:
wandb.finish()