In [1]:
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
import segmentation_models_pytorch as smp
from collections import OrderedDict
import numpy as np
import copy

import os
from collections import OrderedDict
import json
import time

import torch
from torchvision import datasets, transforms
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torchvision
import torchvision.models as models
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from PIL import Image
from torchinfo import summary
import struct
import gzip

# Process Data

In [33]:
# Export transform: only convert the PIL image to a tensor.
# The samples produced here are written to disk with save_image in the cells
# below, and save_image clamps values to [0, 1]. Applying
# Normalize(0.5, 0.5) here would push half of the intensity range below zero
# and clip it away in the saved PNGs. Normalization is applied when the PNGs
# are loaded back for training instead.
transform = transforms.Compose([transforms.ToTensor()])
data = torchvision.datasets.MNIST('./data/', transform=transform, download=True)
# batch_size=1 so each sample can be bucketed by its label; the seeded
# generator makes the shuffle order (and so the later train/test split)
# reproducible across kernel restarts.
data_loader = torch.utils.data.DataLoader(data,
                                          batch_size=1,
                                          shuffle=True,
                                          generator=torch.Generator().manual_seed(0))

In [61]:
# Bucket every sample by its digit label: label -> list of image tensors.
# The loader yields batches of size 1, so element 0 is the only sample.
all_data = {}
for images, labels in data_loader:
    sample = images[0]
    digit = labels.numpy()[0]
    all_data.setdefault(digit, []).append(sample)


In [65]:
# Per label: the first 1000 samples form the test split, the rest the train split.
test = {label: samples[0:1000] for label, samples in all_data.items()}
train = {label: samples[1000:] for label, samples in all_data.items()}

In [67]:
# Digit labels partitioned by parity.
even_idxs = list(range(0, 10, 2))
odd_idxs = list(range(1, 10, 2))

In [68]:
# Containers for the parity-specific splits (label -> list of image tensors).
even_train, odd_train = {}, {}
even_test, odd_test = {}, {}


In [69]:
# Route each label's samples into the odd or even containers, using the same
# 1000-sample cut as the train/test split (first 1000 -> test, rest -> train).
for label, samples in all_data.items():
    is_odd = label in odd_idxs
    test_bucket = odd_test if is_odd else even_test
    train_bucket = odd_train if is_odd else even_train
    test_bucket[label] = samples[0:1000]
    train_bucket[label] = samples[1000:]

In [72]:
from torchvision.utils import save_image

# Write the train split to disk as ./data/train/<label>/<index>.png so it can
# be read back as a class-per-folder image dataset.
path = './data/train/'
for label in train:
    label_dir = os.path.join(path, str(label))
    # makedirs creates any missing parent folder and is a no-op if the
    # directory already exists, so the cell can safely be re-run.
    os.makedirs(label_dir, exist_ok=True)
    curr = train[label]
    for i, image in enumerate(curr):
        save_image(image, os.path.join(label_dir, str(i) + '.png'))

In [73]:
# Write the test split to disk as ./data/test/<label>/<index>.png so it can
# be read back as a class-per-folder image dataset.
path = './data/test/'
for label in test:
    label_dir = os.path.join(path, str(label))
    # makedirs creates any missing parent folder and is a no-op if the
    # directory already exists, so the cell can safely be re-run.
    os.makedirs(label_dir, exist_ok=True)
    curr = test[label]
    for i, image in enumerate(curr):
        save_image(image, os.path.join(label_dir, str(i) + '.png'))

In [74]:
# Write the even-digit test samples to ./data/eval/evens/<label>/<index>.png.
path = './data/eval/evens/'
for label in even_test:
    label_dir = os.path.join(path, str(label))
    # os.mkdir cannot create the nested path because './data/eval/' does not
    # exist yet; makedirs creates every missing level and tolerates re-runs.
    os.makedirs(label_dir, exist_ok=True)
    curr = even_test[label]
    for i, image in enumerate(curr):
        save_image(image, os.path.join(label_dir, str(i) + '.png'))

In [75]:
# Write the odd-digit test samples to ./data/eval/odds/<label>/<index>.png.
path = './data/eval/odds/'
for label in odd_test:
    label_dir = os.path.join(path, str(label))
    # os.mkdir cannot create the nested path when './data/eval/' is missing;
    # makedirs creates every missing level and tolerates re-runs.
    os.makedirs(label_dir, exist_ok=True)
    curr = odd_test[label]
    for i, image in enumerate(curr):
        save_image(image, os.path.join(label_dir, str(i) + '.png'))

In [76]:
# Write the even-digit train samples to ./data/ext/evens/<label>/<index>.png.
path = './data/ext/evens/'
for label in even_train:
    label_dir = os.path.join(path, str(label))
    # os.mkdir cannot create the nested path because './data/ext/' does not
    # exist yet; makedirs creates every missing level and tolerates re-runs.
    os.makedirs(label_dir, exist_ok=True)
    curr = even_train[label]
    for i, image in enumerate(curr):
        save_image(image, os.path.join(label_dir, str(i) + '.png'))

In [77]:
# Write the odd-digit train samples to ./data/ext/odds/<label>/<index>.png.
path = './data/ext/odds/'
for label in odd_train:
    label_dir = os.path.join(path, str(label))
    # os.mkdir cannot create the nested path when './data/ext/' is missing;
    # makedirs creates every missing level and tolerates re-runs.
    os.makedirs(label_dir, exist_ok=True)
    curr = odd_train[label]
    for i, image in enumerate(curr):
        save_image(image, os.path.join(label_dir, str(i) + '.png'))

# Define Model

In [2]:
import numpy as np
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` subclass whose forward pass is spelled out explicitly.

    The constructor forwards every argument to ``nn.Conv2d`` unchanged. The
    forward pass is overridden so the convolution is computed directly from
    ``self.weight`` / ``self.bias`` (the original author notes this is done
    for pruning).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )

    def forward(self, input):
        """Convolve ``input`` with the layer's current weight and bias."""
        # NOTE(review): self.padding is handed straight to F.conv2d, so a
        # non-'zeros' padding_mode does not appear to be honored here.
        return F.conv2d(
            input,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )


class Linear(nn.Linear):
    """``nn.Linear`` subclass whose forward pass is spelled out explicitly.

    The constructor forwards its arguments to ``nn.Linear`` unchanged. The
    forward pass is overridden so the affine map is computed directly from
    ``self.weight`` / ``self.bias`` (the original author notes this is done
    for pruning).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias)

    def forward(self, input):
        """Apply the affine transform using the layer's current parameters."""
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers and a shortcut branch.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = Conv2d(in_planes, planes, kernel_size=3,
                            stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = Conv2d(planes, planes, kernel_size=3,
                            stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already match: an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            # A 1x1 convolution projects the input to the output shape.
            self.shortcut = nn.Sequential(
                Conv2d(in_planes, out_planes,
                       kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)

class ResNet(nn.Module):
    """ResNet classifier with helpers for masking, casting and weight backup.

    The network is a 3x3 convolution stem over 3-channel input, four residual
    stages (64, 128, 256 and 512 planes) and a linear classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` attribute
            and be constructible as ``block(in_planes, planes, stride)``.
        num_blocks: Sequence of four ints, the number of blocks in each stage.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes):
        super(ResNet, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = Conv2d(3, 64, kernel_size=3,
                            stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier outputs for a batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Fixed 4x4 average pooling; after flattening, the feature count must
        # equal 512 * block.expansion for the linear layer to accept it.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def apply_mask(self, mask, sizing):
        """Zero out parts of the weights selected by a flat mask.

        ``sizing`` is iterated in order; each key has the form
        ``"<state_dict key>_<segment index>"`` and its value is the number of
        consecutive ``mask`` entries that belong to that segment. Entries that
        round to 0 select indices along dim 0 of the tensor; for those
        indices, one quarter of dim 1 (chosen by the segment index 1..4) is
        set to zero. Other segment indices zero nothing but are still
        recorded.

        The selected indices are stored in ``self.segments`` as
        ``{state_dict key: {segment index: np.where(...) result}}``.

        Args:
            mask: Flat sequence of mask values.
                # NOTE(review): presumably a 1-D NumPy array - confirm with caller.
            sizing: Mapping from segment key to segment length.
        """
        start = 0
        copy_state = copy.deepcopy(self.state_dict())
        segments = {}
        for key in sizing:
            parts = key.split('_')
            base, seg_idx = parts[0], int(parts[1])
            quart_size = int(copy_state[base].shape[1]/4)
            end = start + sizing[key]
            segment = np.round(mask[start:end])
            index = np.where(segment == 0)

            # Quarter of dim 1 addressed by each segment index; the last
            # quarter is open-ended so it absorbs any remainder.
            quarter_slices = {
                1: slice(0, quart_size),
                2: slice(quart_size, quart_size*2),
                3: slice(quart_size*2, quart_size*3),
                4: slice(quart_size*3, None),
            }
            col_slice = quarter_slices.get(seg_idx)
            if col_slice is not None:
                copy_state[base].data[index, col_slice] = 0

            segments.setdefault(base, {})[seg_idx] = index
            start = end

        self.load_state_dict(copy_state)
        self.segments = segments

    def half(self):
        """Cast every parameter to float16 in place and return ``self``.

        Only parameters are converted; registered buffers are left in their
        current dtype. Returning ``self`` matches ``nn.Module.half`` so that
        ``model = model.half()`` keeps working.
        """
        for name, param in self.named_parameters():
            param.data = param.data.half()
        return self

    def return_model_state(self):
        """Return the module's ``state_dict()``."""
        return self.state_dict()

    def revert_weights(self):
        """Restore the weights saved by ``update_backup`` and re-enable gradients.

        Raises:
            AttributeError: If ``update_backup`` has not been called yet.
        """
        backup = getattr(self, 'weights_backup', None)
        if backup is None:
            raise AttributeError(
                "no weights backup available; call update_backup() first")
        self.load_state_dict(backup)
        for name, param in self.named_parameters():
            param.requires_grad = True

    def update_backup(self):
        """Store a deep copy of the current ``state_dict`` for ``revert_weights``."""
        self.weights_backup = copy.deepcopy(self.state_dict())


def Resnet(type_id, num_classes):
    """Build a ResNet of the requested depth.

    Args:
        type_id: Network depth; one of 18, 34, 50, 101 or 152.
        num_classes: Number of outputs of the classifier.

    Returns:
        A newly constructed ``ResNet``.

    Raises:
        ValueError: If ``type_id`` is not a supported depth.
    """
    basic_layouts = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
    }
    bottleneck_layouts = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }
    if type_id in basic_layouts:
        return ResNet(BasicBlock, basic_layouts[type_id], num_classes)
    if type_id in bottleneck_layouts:
        # NOTE(review): Bottleneck is not defined in the visible part of this
        # notebook; it must be defined before these depths are requested.
        return ResNet(Bottleneck, bottleneck_layouts[type_id], num_classes)
    supported = sorted(list(basic_layouts) + list(bottleneck_layouts))
    raise ValueError(
        "unsupported ResNet type_id {!r}; expected one of {}".format(type_id, supported))


# Define Train Loop

In [12]:
def finetune(model, extloader, evalloader, epochs=5):
    """Train ``model`` and return the best evaluation accuracy and weights.

    Uses Adam (lr 0.0005) with the learning rate multiplied by 0.95 after
    every epoch. After each epoch the model is evaluated with ``test_acc`` and
    a deep copy of the best-scoring ``state_dict`` is kept.

    Args:
        model: Module to train; batches are moved to the device that holds
            its parameters.
        extloader: Iterable of training batches, where ``batch[0]`` holds the
            inputs and ``batch[1]`` the integer class targets.
        evalloader: Iterable of evaluation batches passed to ``test_acc``.
        epochs: Number of passes over ``extloader``.

    Returns:
        Tuple ``(best_acc, best_state)``: the highest accuracy returned by
        ``test_acc`` and a copy of the ``state_dict`` that achieved it.
        ``best_state`` is ``None`` only when ``epochs`` is 0.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer, lr_lambda=lambda epoch: 0.95)
    # Follow the model's device rather than mixing a global device name with
    # hard-coded .cuda(0) calls.
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss().to(device)

    best_acc = 0
    best_state = None
    for epoch in range(epochs):
        model.train()
        for batch in extloader:
            x, y = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            fx = model(x)
            # Outputs are passed as-is: squeezing them would drop the batch
            # dimension when the final batch holds a single sample.
            loss = criterion(fx, y)
            loss.backward()
            optimizer.step()
        acc_per_epoch = test_acc(model, evalloader)
        scheduler.step()
        # Always keep the first epoch's weights so a state is returned even
        # when the accuracy never rises above zero.
        if best_state is None or acc_per_epoch > best_acc:
            best_acc = acc_per_epoch
            best_state = copy.deepcopy(model.state_dict())
        print('Epoch: ', epoch, 'Acc: ', acc_per_epoch)
    return best_acc, best_state

def test_acc(model, evalloader):
    model.eval()
    avg_acc = 0
    with torch.no_grad():
        for i, data in enumerate(evalloader):
            x, y = data[0].cuda(0), data[1].cuda(0)
            fx = model(x)
            _, predicted = fx.max(1)
            acc_per_batch = 100. * predicted.eq(y).sum().item() / y.size(0)
            avg_acc += acc_per_batch
    avg_acc = avg_acc/len(evalloader)
    return avg_acc

# Train

In [13]:
# Training data: the class-per-folder PNGs under ./data/train/, converted to
# tensors and normalized with mean 0.5 and std 0.5.
path = './data/train/'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])
dataset = datasets.ImageFolder(root=path, transform=transform)
trainloader = DataLoader(dataset=dataset, batch_size=512, shuffle=True)

In [14]:
# Evaluation data: the class-per-folder PNGs under ./data/test/, preprocessed
# the same way as the training data (tensor conversion, mean 0.5 / std 0.5).
path = './data/test/'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])
dataset = datasets.ImageFolder(root=path, transform=transform)
testloader = DataLoader(dataset=dataset, batch_size=512, shuffle=True)

In [15]:
# Primary compute device: the first CUDA GPU when one is available,
# otherwise the CPU so the notebook can still be executed.
DEVICE_1 = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [16]:
# ResNet-18 layout (two BasicBlocks per stage) with a 10-way classifier.
# The model is moved to the GPU in the next cell.
model = ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=10)

In [18]:
# Place the network on the notebook's configured device instead of relying on
# whichever CUDA device happens to be current.
model.to(DEVICE_1)
# NOTE: finetune returns (best accuracy, best state_dict). After this line
# `model` holds the best state_dict, not the nn.Module; the save cell below
# relies on that.
acc, model = finetune(model, trainloader, testloader)

Epoch:  0 Acc:  96.33386948529412
Epoch:  1 Acc:  98.43347886029412
Epoch:  2 Acc:  98.10259650735294
Epoch:  3 Acc:  96.46426930147058
Epoch:  4 Acc:  99.03664981617648


In [19]:
# torch.save does not create missing directories, so make sure ./models/
# exists before writing. `model` holds the best state_dict returned by
# finetune at this point, so the file contains a state_dict.
os.makedirs('./models', exist_ok=True)
torch.save(model, './models/base.pth')