In [1]:
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [2]:
DATA_PATH = "./data/"
seed = 42
device = "cuda"

In [3]:
df =  pd.read_csv(DATA_PATH +  "train.csv")

In [4]:
y = df["label"].values
X = df.drop("label", axis=1).values
X = X.reshape((len(X), 1,  28, 28))

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=seed)

In [5]:
class MNIST(Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X).float()
        self.y = torch.from_numpy(y).long()
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx, :, :, :], self.y[idx]

In [6]:
batch_size = 128
num_workers = 4

train_dataset = MNIST(X_train, y_train)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

val_dataset = MNIST(X_val, y_val)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [7]:
class BNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x


class Block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_1 = BNLayer(in_channels, in_channels)
        self.conv_2 = BNLayer(in_channels, out_channels, stride=2)
    
    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        return x
    
class Net(nn.Module):
    def __init__(self, channels, num_classes=10):
        super().__init__()
        
        self.conv_blocks = []
        for i in range(len(channels) - 1):
            self.conv_blocks.append(Block(channels[i], channels[i+1]))
        self.conv_blocks = nn.Sequential(*self.conv_blocks)
        
        self.linear = nn.Linear(32, num_classes)
    
    def forward(self, x):
        x = self.conv_blocks(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(len(x), -1)
        x = self.linear(x)
        return x

In [8]:
torch.manual_seed(seed)

channels = [1, 8, 16, 32]
model = Net(channels).to(device)

In [9]:
num_epochs = 40
lr = 1e-4
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fct = nn.CrossEntropyLoss()

In [10]:
def calculate_accuracy(y_pred, y_true):
    y_pred = F.softmax(y_pred, dim=1)
    y_pred = torch.argmax(y_pred, dim=1)
    correct =  y_pred  == y_true
    return torch.mean(correct.float()).numpy()

In [11]:
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

for epoch in range(num_epochs):
    train_loss, val_loss =  0.0, 0.0    
    val_acc = 0.0
    
    model.train()
    for _, (x, y) in enumerate(train_dataloader):
        model.zero_grad()
        y, x  = y.to(device), x.to(device)
        
        y_hat = model(x)
        batch_loss = loss_fct(y_hat, y)
        batch_loss.backward()
        optimizer.step()
        
        train_loss += batch_loss.detach().cpu().numpy() / len(y)

    model.eval()
    for _, (x, y) in enumerate(val_dataloader):
        y, x  = y.to(device), x.to(device)
        y_hat = model(x)
        batch_loss = loss_fct(y_hat, y)
        val_loss += batch_loss.detach().cpu().numpy() / len(y)
        val_acc += calculate_accuracy(y_hat.detach().cpu(), y.cpu())
    
    train_loss = np.round(train_loss / len(train_dataloader), 6)
    val_loss = np.round(val_loss / len(val_dataloader), 6)
    val_acc = np.round(val_acc / len(val_dataloader), 6)
    
    if epoch % 5 == 0:
        print(f"-------- Epoch {epoch} --------")
        print(f"Train loss: {train_loss}")
        print(f"Val loss: {val_loss}")
        print(f"Val acc: {val_acc}")

-------- Epoch 0 --------
Train loss: 0.016511
Val loss: 0.014961
Val acc: 0.491398
-------- Epoch 5 --------
Train loss: 0.006096
Val loss: 0.005539
Val acc: 0.890478
-------- Epoch 10 --------
Train loss: 0.002628
Val loss: 0.002486
Val acc: 0.949777
-------- Epoch 15 --------
Train loss: 0.001524
Val loss: 0.001486
Val acc: 0.96264
-------- Epoch 20 --------
Train loss: 0.001052
Val loss: 0.001086
Val acc: 0.968581
-------- Epoch 25 --------
Train loss: 0.000797
Val loss: 0.000883
Val acc: 0.971816
-------- Epoch 30 --------
Train loss: 0.00064
Val loss: 0.000755
Val acc: 0.97318
-------- Epoch 35 --------
Train loss: 0.000532
Val loss: 0.000665
Val acc: 0.975728


In [16]:
[module for module in model.modules() if type(module) == nn.Conv2d]

[Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
 Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)]

In [17]:
prune_level = 0.5

In [18]:
from torch.nn.utils import prune

for module in model.modules():
    if type(module) == nn.Conv2d:
        prune.ln_structured(module=module, name='weight', amount=prune_level, n=2, dim=1)

In [22]:
module.weight_mask

AttributeError: 'Linear' object has no attribute 'weight_mask'

In [23]:
module

Linear(in_features=32, out_features=10, bias=True)

In [37]:
for module in model.modules():
    if type(module)  == nn.Conv2d:
        print(module.weight_mask.shape)
        break

torch.Size([1, 1, 3, 3])


In [34]:
list(model.modules())[0]

Net(
  (conv_blocks): Sequential(
    (0): Block(
      (conv_1): BNLayer(
        (conv): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
      )
      (conv_2): BNLayer(
        (conv): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
      )
    )
    (1): Block(
      (conv_1): BNLayer(
        (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
      )
      (conv_2): BNLayer(
        (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr