In [14]:
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder, CIFAR100
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
import wandb
%matplotlib inline

matplotlib.rcParams['figure.facecolor'] = '#ffffff'

In [15]:
import ssl

train_transform=T.Compose([#T.Resize(32),
                           T.RandomCrop(32, padding=4, padding_mode='reflect'),
                           T.RandomHorizontalFlip(),
                           T.ToTensor(),T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

valid_transform=T.Compose([T.ToTensor(),T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

torch.manual_seed(1)

train_ds=CIFAR100(root="data",train=True,transform=train_transform,download=False)
valid_ds=CIFAR100(root="data",train=False,transform=valid_transform,download=False)

In [16]:
batch_size = 256



In [17]:
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size*2, num_workers=3, pin_memory=True)

In [45]:

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam,index


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)



def mix_freq(img,label,b_kernel,e_kernel,lamB=0.5,lamE=0.5):
    with torch.no_grad():
        print(img)
        blured_img=b_kernel(img)
        edged_img=e_kernel(img)[1].expand(-1,3,-1,-1)
        idx = torch.randperm(blured_img.shape[0])
        edged_img=edged_img[idx]
        y_b,y_e=label,label[idx]
        mixed_img=lamB*blured_img+lamE*edged_img#+0.2*img
    return mixed_img,y_b,y_e,lamB,lamE,idx
    
def mix_freq_classic(img,label,b_kernel,alpha=1):
    lam=np.random.beta(alpha,alpha)
    with torch.no_grad():
        edged_img=img-b_kernel(img)
    idx = torch.randperm(img.shape[0])
    edged_img=edged_img[idx]
    y_a,y_e=label,label[idx]
    mixed_img=lam*img+(1-lam)*edged_img
    ##+0.2*img
    return mixed_img,y_a,y_e,lam

def mixfreq_criterion(criterion, pred, y_a, y_b):
    return  criterion(pred, y_a) + criterion(pred, y_b)


In [46]:
import torch.nn as nn

In [47]:
from torchvision.models.resnet import conv3x3, _resnet

class PreactBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(PreactBasicBlock, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')

        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)

        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out


In [48]:
model = _resnet('resnet18', PreactBasicBlock, [2, 2, 2, 2], False, False)
model.fc = nn.Linear(model.fc.in_features, 100)


In [49]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch 
        images=images.to(self.device)
        labels=labels.to(self.device)
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels)
        return loss
    def training_step_mixup(self, batch,epoch,criterion):
        images, labels = batch 
        images=images.to(self.device)
        labels=labels.to(self.device)
        if epoch<550:
            mixed_x, y_a, y_b, lam,index=mixup_data(images,labels)
            out = self(mixed_x)                  # Generate predictions
            loss = mixup_criterion(criterion,out,y_a,y_b,lam)
        else:
            out = self(images)                  # Generate predictions
            loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    def training_step_mixfreq(self, batch,epoch,criterion,b_kernel,e_kernel):
        images, labels = batch 
        images=images.to(self.device)
        labels=labels.to(self.device)
        if epoch<250:
            mixed_x, y_a, y_b, lamB,lamE,index=mix_freq(images,labels,b_kernel,e_kernel)
            #mixed_x=mixed_x.to(self.device)
            out = self(mixed_x)                  # Generate predictions
            loss = mixfreq_criterion(criterion,out,y_a,y_b)
        else:
            out = self(images)                  # Generate predictions
            loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        
        images, labels = batch 
        images=images.to(self.device)
        labels=labels.to(self.device)
 #       print(labels)
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        wandb.log({
        "epoch": epoch,
        "Train Loss": result['train_loss'],
        "Val Accuracy": result['val_acc'],
        "Val Loss": result['val_loss'],
        "Learning Rate": result['lrs'][-1]},
            )
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))

In [50]:
class Base_Model(ImageClassificationBase):
    def __init__(self,backbone,device):
        super().__init__()
        
        self.backbone=backbone
        self.device=device
    def forward(self, xb):
        out = self.backbone(xb)
        return out

In [51]:
from torchvision.models.resnet import resnet18

In [54]:
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD,blur_kernel=None,edge_kernel=None):
    torch.cuda.empty_cache()
    history = []
    
    # Set up cutom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, 
                                                steps_per_epoch=len(train_loader))
    
    for epoch in range(epochs):
        # Training Phase 
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step_mixfreq(batch,epoch,nn.CrossEntropyLoss(),blur_kernel,edge_kernel)
            return batch
            #loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            
            # Gradient clipping
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            
            optimizer.step()
            optimizer.zero_grad()
            
            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()
        
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [55]:
history = [evaluate(model, valid_dl)]
history

[{'val_loss': 4.848124980926514, 'val_acc': 0.010880054906010628}]

In [56]:

model=Base_Model(model,device='cuda')
model.to('cuda')
print("MDODEL DEVICE")
from kernel import CannyFilter,get_thin_kernels,get_gaussian_kernel
blur_kernel=CannyFilter(k_gaussian=3,k_sobel=3,blur=True).eval()
blur_kernel.to('cuda')
edge_kernel=CannyFilter(k_gaussian=1,k_sobel=3,blur=False).eval()
edge_kernel.to('cuda')


MDODEL DEVICE


CannyFilter(
  (gaussian_filter): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (sobel_filter_x): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (sobel_filter_y): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (directional_filter): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (hysteresis): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)

In [57]:
for batch in train_dl:
    model.training_step_mixfreq(batch,5,nn.CrossEntropyLoss(),blur_kernel,edge_kernel)
    break

tensor([[[[ 0.9132,  0.6734,  0.2111,  ...,  1.7694,  1.9578,  2.1462],
          [ 0.6563,  1.0331,  1.0331,  ...,  1.8208,  2.0263,  2.1633],
          [ 0.3652,  1.0502,  0.7933,  ...,  1.8379,  2.0434,  2.1804],
          ...,
          [-1.4843, -1.0733, -0.6452,  ...,  0.1768,  0.1254,  0.1597],
          [-1.2617, -0.9534, -0.6452,  ...,  0.2282,  0.1768,  0.1939],
          [-1.1075, -0.9534, -0.7479,  ...,  0.2624,  0.2453,  0.1939]],

         [[ 1.0280,  0.8354,  0.4328,  ...,  1.9734,  2.1835,  2.3235],
          [ 0.9230,  1.2031,  1.1331,  ...,  2.0434,  2.2360,  2.3235],
          [ 0.7304,  1.3957,  1.0280,  ...,  2.0784,  2.2535,  2.3585],
          ...,
          [-1.4230, -0.9853, -0.4951,  ...,  0.2402, -0.1099, -0.1450],
          [-1.1429, -0.8277, -0.5301,  ...,  0.0301, -0.1450, -0.0924],
          [-0.9503, -0.8277, -0.6176,  ..., -0.1099, -0.0924, -0.1099]],

         [[ 1.1585,  1.0017,  0.6879,  ...,  2.2740,  2.4308,  2.5354],
          [ 1.3154,  1.3328,  

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

In [18]:
epochs = 300
max_lr = 0.01
grad_clip = 0.1
weight_decay = 0.5e-4
opt_func = torch.optim.Adam
wandb.init(project="Augmentation_Strategy_CIFAR100")
wandb.run.name = "MixFreq_PreAct_ResNet18_0.01_300"

[34m[1mwandb[0m: Currently logged in as: [33mhslrock[0m (use `wandb login --relogin` to force relogin)


In [25]:
%%time
history += fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl, 
                             grad_clip=grad_clip, 
                             weight_decay=weight_decay, 
                             opt_func=opt_func,)

tensor([[[[-0.1314,  0.0056, -0.0972,  ...,  1.4098,  1.2557,  1.0844],
          [ 0.0227,  0.0398,  0.1083,  ...,  1.4269,  1.3070,  1.1529],
          [ 0.0912,  0.2453,  0.2796,  ...,  1.5639,  1.4783,  1.2728],
          ...,
          [ 0.8789,  1.1529,  0.3823,  ...,  0.8789,  0.7933,  0.5022],
          [ 0.1939,  0.2624, -0.1828,  ...,  1.5639,  1.4612,  1.2214],
          [ 0.5536,  0.6392,  0.6392,  ...,  1.8722,  1.8208,  1.8037]],

         [[-0.0399,  0.1001,  0.0126,  ...,  1.2206,  1.0980,  0.9230],
          [ 0.0826,  0.1001,  0.1877,  ...,  1.1155,  1.0105,  0.9405],
          [ 0.1001,  0.3102,  0.3452,  ...,  1.2381,  1.1331,  1.0280],
          ...,
          [ 0.5903,  0.8354,  0.1702,  ...,  0.4503,  0.3803,  0.2227],
          [ 0.0126,  0.0651, -0.3025,  ...,  1.0105,  0.9580,  0.8179],
          [ 0.3803,  0.4678,  0.4853,  ...,  1.2031,  1.2031,  1.2381]],

         [[-0.7064, -0.5321, -0.7238,  ..., -0.0964, -0.2010, -0.4101],
          [-0.6890, -0.6367, -

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

## VANILA RESNET18 (NO AUG)

Epoch [94], last_lr: 0.00013, train_loss: 0.3673, val_loss: 1.6711, val_acc: 0.6126  
Epoch [95], last_lr: 0.00008, train_loss: 0.3603, val_loss: 1.6710, val_acc: 0.6131  
Epoch [96], last_lr: 0.00005, train_loss: 0.3479, val_loss: 1.6732, val_acc: 0.6127  
Epoch [97], last_lr: 0.00002, train_loss: 0.3484, val_loss: 1.6774, val_acc: 0.6133  
Epoch [98], last_lr: 0.00001, train_loss: 0.3424, val_loss: 1.6757, val_acc: 0.6131  
Epoch [99], last_lr: 0.00000, train_loss: 0.3499, val_loss: 1.6706, val_acc: 0.6132  

## VANILLA RESNET18 (MIXUP)
Epoch [595], last_lr: 0.00000, train_loss: 0.7558, val_loss: 1.3887, val_acc: 0.6273   
Epoch [596], last_lr: 0.00000, train_loss: 0.7481, val_loss: 1.3862, val_acc: 0.6294  
Epoch [597], last_lr: 0.00000, train_loss: 0.7560, val_loss: 1.3890, val_acc: 0.6275  
Epoch [598], last_lr: 0.00000, train_loss: 0.7582, val_loss: 1.3909, val_acc: 0.6296  
Epoch [599], last_lr: 0.00000, train_loss: 0.7508, val_loss: 1.3876, val_acc: 0.6293  