# <center> IMPORT MODULES </center> 

In [None]:
import torch
import pprint
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F

import wandb
from wandb.keras import WandbCallback

# Select the compute device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# <center> DIFFERENTIAL CONVOLUTION </center>

In [None]:
# Variant 1 -- combined channel, assuming 8-bit pixel inputs:
# [-765 to 255]  --> [-255 to 85] --> [0 to 255]
class DiffConv6_1(nn.Module):
    """Parameter-free layer stacking an image with five neighbour-difference maps.

    For an input ``x`` of shape ``(N, C, H, W)`` the output has shape
    ``(N, 6*C, H, W)``. It is allocated with the default float dtype on
    ``x.device``. It consists of six groups of ``C`` channels. With
    ``p = x[..., r, c]``, ``below = x[..., r+1, c]``, ``right = x[..., r, c+1]``
    and ``diag = x[..., r+1, c+1]``:

    * group 0: ``x`` unchanged (the only group gradients flow through);
    * group 1: ``|x[..., r, :] - x[..., r+1, :]|`` (difference along dim 2);
    * group 2: ``|x[..., :, c] - x[..., :, c+1]|`` (difference along dim 3);
    * group 3: ``|p - diag|``  ("south east");
    * group 4: ``|below - right|``  ("south west");
    * group 5: ``|trunc((p - below - right - diag) / 3)|``.

    Positions whose neighbour would fall outside the image are left at 0. That
    means the last row for groups 1, 3-5 and the last column for groups 2-5.
    Groups 1-5 are computed from ``x.detach()``, as in the previous per-sample
    implementation.

    NOTE(review): the range comments assume raw pixel values in [0, 255].
    The CIFAR-10 pipeline in this notebook feeds Normalize()'d tensors
    (roughly unit scale). At that scale the truncation in group 5 keeps only a
    handful of distinct levels. Confirm this is intended.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        ins, n, sx, sy = x.size()

        # Group 0: the raw input, kept differentiable.
        out = torch.zeros(ins, n * 6, sx, sy, device=x.device)
        out[:, 0:n] = x

        # Groups 1-5 are built from a detached view in a separate buffer and
        # copied into ``out`` once. No autograd-tracked tensor is read and then
        # modified in place (the "Inplace error" the old code worked around).
        # The whole batch is processed at once instead of looping per sample.
        d = x.detach()
        p = d[:, :, 0:sx-1, 0:sy-1]        # pixel (r, c)
        below = d[:, :, 1:sx, 0:sy-1]      # neighbour (r+1, c)
        right = d[:, :, 0:sx-1, 1:sy]      # neighbour (r, c+1)
        diag = d[:, :, 1:sx, 1:sy]         # neighbour (r+1, c+1)

        diffs = torch.zeros(ins, n * 5, sx, sy, device=x.device)
        diffs[:, 0:n, 0:sx-1, :] = d[:, :, 0:sx-1, :] - d[:, :, 1:sx, :]
        diffs[:, n:2*n, :, 0:sy-1] = d[:, :, :, 0:sy-1] - d[:, :, :, 1:sy]
        diffs[:, 2*n:3*n, 0:sx-1, 0:sy-1] = p - diag
        diffs[:, 3*n:4*n, 0:sx-1, 0:sy-1] = below - right
        # Combined channel: scale by 1/3, then truncate toward zero via int().
        diffs[:, 4*n:5*n, 0:sx-1, 0:sy-1] = ((p - below - right - diag) / 3).int()

        out[:, n:6*n] = diffs.abs()

        # Kept so code reading ``module.output`` still sees the last result.
        self.output = out
        return out

In [None]:
# Variant 2 -- combined channel weights the centre pixel by 3, assuming 8-bit inputs:
# [-765 to 765]  --> [-255 to 255] --> [0 to 255]

class DiffConv6_2(nn.Module):
    """Parameter-free layer stacking an image with five neighbour-difference maps.

    For an input ``x`` of shape ``(N, C, H, W)`` the output has shape
    ``(N, 6*C, H, W)``. It is allocated with the default float dtype on
    ``x.device``. It consists of six groups of ``C`` channels. With
    ``p = x[..., r, c]``, ``below = x[..., r+1, c]``, ``right = x[..., r, c+1]``
    and ``diag = x[..., r+1, c+1]``:

    * group 0: ``x`` unchanged (the only group gradients flow through);
    * group 1: ``|x[..., r, :] - x[..., r+1, :]|`` (difference along dim 2);
    * group 2: ``|x[..., :, c] - x[..., :, c+1]|`` (difference along dim 3);
    * group 3: ``|p - diag|``  ("south east");
    * group 4: ``|below - right|``  ("south west");
    * group 5: ``|trunc((3*p - below - right - diag) / 3)|``.

    Positions whose neighbour would fall outside the image are left at 0. That
    means the last row for groups 1, 3-5 and the last column for groups 2-5.
    Groups 1-5 are computed from ``x.detach()``, as in the previous per-sample
    implementation.

    NOTE(review): the range comments assume raw pixel values in [0, 255].
    The CIFAR-10 pipeline in this notebook feeds Normalize()'d tensors
    (roughly unit scale). At that scale the truncation in group 5 keeps only a
    handful of distinct levels. Confirm this is intended.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        ins, n, sx, sy = x.size()

        # Group 0: the raw input, kept differentiable.
        out = torch.zeros(ins, n * 6, sx, sy, device=x.device)
        out[:, 0:n] = x

        # Groups 1-5 are built from a detached view in a separate buffer and
        # copied into ``out`` once. No autograd-tracked tensor is read and then
        # modified in place (the "Inplace error" the old code worked around).
        # The whole batch is processed at once instead of looping per sample.
        d = x.detach()
        p = d[:, :, 0:sx-1, 0:sy-1]        # pixel (r, c)
        below = d[:, :, 1:sx, 0:sy-1]      # neighbour (r+1, c)
        right = d[:, :, 0:sx-1, 1:sy]      # neighbour (r, c+1)
        diag = d[:, :, 1:sx, 1:sy]         # neighbour (r+1, c+1)

        diffs = torch.zeros(ins, n * 5, sx, sy, device=x.device)
        diffs[:, 0:n, 0:sx-1, :] = d[:, :, 0:sx-1, :] - d[:, :, 1:sx, :]
        diffs[:, n:2*n, :, 0:sy-1] = d[:, :, :, 0:sy-1] - d[:, :, :, 1:sy]
        diffs[:, 2*n:3*n, 0:sx-1, 0:sy-1] = p - diag
        diffs[:, 3*n:4*n, 0:sx-1, 0:sy-1] = below - right
        # Combined channel (centre pixel * 3): scale by 1/3, truncate toward zero.
        diffs[:, 4*n:5*n, 0:sx-1, 0:sy-1] = ((3 * p - below - right - diag) / 3).int()

        out[:, n:6*n] = diffs.abs()

        # Kept so code reading ``module.output`` still sees the last result.
        self.output = out
        return out

In [None]:
# Variant 3 -- combined channel is offset then halved, assuming 8-bit inputs:
# [-765 to 255]  --> [-510 to 510] --> [-255 to 255] --> [0 to 255]

class DiffConv6_3(nn.Module):
    """Parameter-free layer stacking an image with five neighbour-difference maps.

    For an input ``x`` of shape ``(N, C, H, W)`` the output has shape
    ``(N, 6*C, H, W)``. It is allocated with the default float dtype on
    ``x.device``. It consists of six groups of ``C`` channels. With
    ``p = x[..., r, c]``, ``below = x[..., r+1, c]``, ``right = x[..., r, c+1]``
    and ``diag = x[..., r+1, c+1]``:

    * group 0: ``x`` unchanged (the only group gradients flow through);
    * group 1: ``|x[..., r, :] - x[..., r+1, :]|`` (difference along dim 2);
    * group 2: ``|x[..., :, c] - x[..., :, c+1]|`` (difference along dim 3);
    * group 3: ``|p - diag|``  ("south east");
    * group 4: ``|below - right|``  ("south west");
    * group 5: ``|trunc((p - below - right - diag + 255) / 2)|``.

    Positions whose neighbour would fall outside the image are left at 0. That
    means the last row for groups 1, 3-5 and the last column for groups 2-5.
    The +255 offset is applied only where group 5 is defined. Groups 1-5 are
    computed from ``x.detach()``, as in the previous per-sample implementation.

    NOTE(review): the +255 offset assumes raw pixel values in [0, 255]. The
    CIFAR-10 pipeline in this notebook feeds Normalize()'d tensors (roughly
    unit scale). For those inputs group 5 sits near 127 while every other
    channel is near 0. Confirm this is intended.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        ins, n, sx, sy = x.size()

        # Group 0: the raw input, kept differentiable.
        out = torch.zeros(ins, n * 6, sx, sy, device=x.device)
        out[:, 0:n] = x

        # Groups 1-5 are built from a detached view in a separate buffer and
        # copied into ``out`` once. No autograd-tracked tensor is read and then
        # modified in place (the "Inplace error" the old code worked around).
        # The whole batch is processed at once instead of looping per sample.
        d = x.detach()
        p = d[:, :, 0:sx-1, 0:sy-1]        # pixel (r, c)
        below = d[:, :, 1:sx, 0:sy-1]      # neighbour (r+1, c)
        right = d[:, :, 0:sx-1, 1:sy]      # neighbour (r, c+1)
        diag = d[:, :, 1:sx, 1:sy]         # neighbour (r+1, c+1)

        diffs = torch.zeros(ins, n * 5, sx, sy, device=x.device)
        diffs[:, 0:n, 0:sx-1, :] = d[:, :, 0:sx-1, :] - d[:, :, 1:sx, :]
        diffs[:, n:2*n, :, 0:sy-1] = d[:, :, :, 0:sy-1] - d[:, :, :, 1:sy]
        diffs[:, 2*n:3*n, 0:sx-1, 0:sy-1] = p - diag
        diffs[:, 3*n:4*n, 0:sx-1, 0:sy-1] = below - right
        # Combined channel: shift by +255, halve, then truncate toward zero.
        diffs[:, 4*n:5*n, 0:sx-1, 0:sy-1] = ((p - below - right - diag + 255) / 2).int()

        out[:, n:6*n] = diffs.abs()

        # Kept so code reading ``module.output`` still sees the last result.
        self.output = out
        return out

# <center> DATASET </center>

In [None]:
# Per-channel CIFAR-10 mean / std shared by both preprocessing pipelines.
_cifar10_mean = (0.4914, 0.4822, 0.4465)
_cifar10_std = (0.247, 0.243, 0.261)

# Training pipeline: augmentation (padded random 32x32 crop, horizontal flip,
# random rotation within +/-20 degrees), then tensor conversion + normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(_cifar10_mean, _cifar10_std),
])

# Evaluation pipeline: same normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_cifar10_mean, _cifar10_std),
])

# CIFAR-10 train and test splits under ./data (downloaded if missing).
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=pipeline)
    for is_train, pipeline in ((True, transform_train), (False, transform_test))
)


# <center> Network Architectures </center>

In [None]:
# Width multiplier applied to every conv / fully connected layer of the nets below.
multiplier = 4

# CNN with DiffConv6_1
class NetDiffConv6_1(nn.Module):
    """Three-stage CNN classifier fed by a DiffConv6_1 front end.

    ``forward`` expects ``(N, in_channels, 32, 32)`` images. The three 2x2
    max-pools reduce 32x32 to 4x4, which is the size ``fc1`` is built for.

    Args:
        in_channels: channels of the input image (3 for CIFAR-10). The
            DiffConv front end emits ``6 * in_channels`` channels into conv1.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()

        self.diff = DiffConv6_1()

        self.conv1 = nn.Conv2d(6 * in_channels, 16*multiplier, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16*multiplier)
        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(16*multiplier, 32*multiplier, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32*multiplier)

        self.conv3 = nn.Conv2d(32*multiplier, 64*multiplier, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64*multiplier)

        self.fc1 = nn.Linear(64*multiplier * 4 * 4, 256*multiplier)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(256*multiplier, 128*multiplier)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(128*multiplier, num_classes)

    def forward(self, x):
        """Return raw class scores of shape ``(N, num_classes)``."""
        x = self.diff(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x) # 16x16
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x) # 8x8
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x) # 4x4

        # Flatten per sample. Unlike view(-1, 64*multiplier*4*4), this cannot
        # silently fold leftover spatial positions into the batch dimension when
        # the input is not 32x32; fc1 raises a clear shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


# CNN with DiffConv6_2
class NetDiffConv6_2(nn.Module):
    """Three-stage CNN classifier fed by a DiffConv6_2 front end.

    ``forward`` expects ``(N, in_channels, 32, 32)`` images. The three 2x2
    max-pools reduce 32x32 to 4x4, which is the size ``fc1`` is built for.

    Args:
        in_channels: channels of the input image (3 for CIFAR-10). The
            DiffConv front end emits ``6 * in_channels`` channels into conv1.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()

        self.diff = DiffConv6_2()

        self.conv1 = nn.Conv2d(6 * in_channels, 16*multiplier, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16*multiplier)
        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(16*multiplier, 32*multiplier, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32*multiplier)

        self.conv3 = nn.Conv2d(32*multiplier, 64*multiplier, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64*multiplier)

        self.fc1 = nn.Linear(64*multiplier * 4 * 4, 256*multiplier)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(256*multiplier, 128*multiplier)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(128*multiplier, num_classes)

    def forward(self, x):
        """Return raw class scores of shape ``(N, num_classes)``."""
        x = self.diff(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x) # 16x16
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x) # 8x8
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x) # 4x4

        # Flatten per sample. Unlike view(-1, 64*multiplier*4*4), this cannot
        # silently fold leftover spatial positions into the batch dimension when
        # the input is not 32x32; fc1 raises a clear shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x

# CNN with DiffConv6_3
class NetDiffConv6_3(nn.Module):
    """Three-stage CNN classifier fed by a DiffConv6_3 front end.

    ``forward`` expects ``(N, in_channels, 32, 32)`` images. The three 2x2
    max-pools reduce 32x32 to 4x4, which is the size ``fc1`` is built for.

    Args:
        in_channels: channels of the input image (3 for CIFAR-10). The
            DiffConv front end emits ``6 * in_channels`` channels into conv1.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()

        self.diff = DiffConv6_3()

        self.conv1 = nn.Conv2d(6 * in_channels, 16*multiplier, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16*multiplier)
        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(16*multiplier, 32*multiplier, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32*multiplier)

        self.conv3 = nn.Conv2d(32*multiplier, 64*multiplier, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64*multiplier)

        self.fc1 = nn.Linear(64*multiplier * 4 * 4, 256*multiplier)
        self.dropout1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(256*multiplier, 128*multiplier)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(128*multiplier, num_classes)

    def forward(self, x):
        """Return raw class scores of shape ``(N, num_classes)``."""
        x = self.diff(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x) # 16x16
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x) # 8x8
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x) # 4x4

        # Flatten per sample. Unlike view(-1, 64*multiplier*4*4), this cannot
        # silently fold leftover spatial positions into the batch dimension when
        # the input is not 32x32; fc1 raises a clear shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


# <center> WandB Config </center>

In [None]:
# Set up Weights and Biases: authenticate this session with W&B so the sweep
# and the runs created further down can log to it.
wandb.login()

In [None]:
# W&B sweep definition: a grid over the three DiffConv front-end networks,
# each run ranked by the 'val_loss' value that main() logs.
sweep_config = dict(
    method='grid',
    metric=dict(name='val_loss', goal='minimize'),
    parameters=dict(
        model=dict(values=['NetDiffConv6_1', 'NetDiffConv6_2', 'NetDiffConv6_3']),
    ),
)

pprint.pprint(sweep_config)

In [None]:
# Register the sweep with W&B under this project; the returned id is what
# wandb.agent() uses to launch the individual runs.
sweep_id = wandb.sweep(sweep_config, project="DiffConv6_Cifar10_Test_0")

# <center> TRAINING ++ </center>

In [None]:
# Train loop
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader``.

    Args:
        model: network to optimise; it is switched to ``train()`` mode.
        dataloader: iterable of ``(inputs, targets)`` batches.
        criterion: loss whose value is treated as a per-batch mean (it is
            multiplied by the batch size to accumulate a per-sample sum).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean per-sample loss and top-1
        accuracy over the samples actually seen this epoch. Both are NaN
        when the dataloader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in tqdm(dataloader, desc='Training', leave=False):
        inputs, targets = inputs.to(device), targets.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # compute statistics
        running_loss += loss.item() * inputs.size(0)
        correct += outputs.argmax(dim=1).eq(targets).sum().item()
        total += targets.size(0)

    if total == 0:
        return float('nan'), float('nan')
    # Normalise by the samples actually processed. len(dataloader.dataset) is
    # wrong whenever a sampler subsets the data or drop_last discards a batch.
    return running_loss / total, correct / total

# Test loop
def test(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to ``eval()`` mode.
        dataloader: iterable of ``(inputs, targets)`` batches.
        criterion: loss whose value is treated as a per-batch mean (it is
            multiplied by the batch size to accumulate a per-sample sum).
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean per-sample loss and top-1
        accuracy over the samples actually seen. Both are NaN when the
        dataloader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc='Testing', leave=False):
            inputs, targets = inputs.to(device), targets.to(device)

            # forward
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # compute statistics
            running_loss += loss.item() * inputs.size(0)
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += targets.size(0)

    if total == 0:
        return float('nan'), float('nan')
    # Normalise by the samples actually processed. len(dataloader.dataset) is
    # wrong whenever a sampler subsets the data or drop_last discards a batch.
    return running_loss / total, correct / total


In [None]:
# Data / schedule length shared by every sweep run.
batch_size = 16
num_epochs = 50

# Adam optimiser settings (AMSGrad variant).
learning_rate = 1e-4
betas = (0.85, 0.999)
amsgrad = True

# StepLR settings; only referenced by the commented-out StepLR line in main().
lr_step_size = 10
lr_gamma = 0.1

In [None]:
def main(config = None):
    """Run one W&B sweep trial: build the chosen model, then train and evaluate it.

    Args:
        config: optional config forwarded to ``wandb.init``. Under
            ``wandb.agent`` the sweep supplies ``config.model``, which must
            name one of the NetDiffConv6_* classes.

    Raises:
        ValueError: if ``config.model`` is not a known architecture.

    Side effects: logs per-epoch metrics to W&B and saves the weights with the
    best validation accuracy to ``<model_name>.pt``.
    """
    # Define data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

    model_classes = {
        'NetDiffConv6_1': NetDiffConv6_1,
        'NetDiffConv6_2': NetDiffConv6_2,
        'NetDiffConv6_3': NetDiffConv6_3,
    }

    # Train and evaluate models
    with wandb.init(config = config):
        config = wandb.config
        model_name = config.model

        # Fail loudly on an unknown name instead of hitting an unbound `model`.
        model_cls = model_classes.get(model_name)
        if model_cls is None:
            raise ValueError(
                f'Unknown model {model_name!r}; expected one of {sorted(model_classes)}'
            )
        model = model_cls().to(device)

        print(f'Training {model_name} model...')

        optimizer = optim.Adam(model.parameters(),
                               lr=learning_rate,
                               betas = betas,
                               amsgrad = amsgrad,
                              )

        # LEARNING RATE SCHEDULERS
        # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
        # T_0 must be a positive integer, so guard against num_epochs < 3.
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=max(1, num_epochs // 3), T_mult=2, eta_min=0)

        criterion = nn.CrossEntropyLoss()
        best_val_acc = 0.0
        for epoch in range(1, num_epochs+1):
            train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
            # Advance the schedule once per epoch, AFTER the optimizer steps.
            # Stepping it first (as before) skips the schedule's initial LR.
            scheduler.step()
            val_loss, val_acc = test(model, test_loader, criterion, device)
            print(f'Epoch {epoch}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            # log results to weights and biases
            wandb.log({'train_loss': train_loss,
                       'train_acc': train_acc,
                       'val_loss': val_loss,
                       'val_acc': val_acc})
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), f'{model_name}.pt')
        print(f'{model_name} model finished training with best validation accuracy of {best_val_acc:.4f}')


In [None]:
# Launch up to 3 sweep runs, each calling main(); the grid above has exactly
# three model values, so this should train each network once.
wandb.agent(sweep_id, main, count=3)

In [None]:
# import time
# import requests

# while True:
#     try:
#         # Replace the URL with the Kaggle page you want to keep alive
#         requests.get("https://www.kaggle.com/your-username/your-notebook")
#     except:
#         pass
    
#     # Adjust the sleep time to your liking, but keep it under 40 minutes
#     time.sleep(300) # 5 minutes