In [None]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_ds = datasets.MNIST(root='./data', train=True,  download=False, transform=transform)
test_ds = datasets.MNIST(root='./data', train=False, download=False, transform=transform)

In [None]:
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

In [None]:
imgs, labels = next(iter(train_loader))
imgs.shape, imgs.nelement(), labels, set(sorted(labels.tolist()))

In [None]:
plt.imshow(imgs[0][0], cmap="grey")

In [None]:
out_channels = 8
in_channels = 1 # grey images
kernel_size = 3

weight = torch.randn((out_channels, in_channels, kernel_size, kernel_size), dtype=torch.float) * 0.1
bias = torch.zeros(out_channels)
weight.shape

In [None]:
Kh, Kw = 4, 6
Sh, Sw = 2, 1
Ph, Pw = 2, 3

B, C, H, W = 2, 3, 27, 32

unfold = torch.nn.Unfold(kernel_size=(Kh, Kw), stride=(Sh, Sw), padding=(Ph, Pw))
input = torch.randn(B, C, H, W)
output = unfold(input)

print(output.shape)
output.shape == torch.Size([B, C*Kh*Kw, ((H+2*Ph-Kh) // Sh + 1) * ((W+2*Pw-Kw) // Sw + 1)])

In [None]:
patches = F.unfold(imgs, kernel_size=(3,3), stride=1)
patches.shape, patches.nelement()

In [None]:
w_flat = weight.view(8, in_channels*(kernel_size**2))
w_flat.shape

In [None]:
conv = w_flat @ patches + bias.unsqueeze(1) # [64, 8, 9] * [64, 9, 676] -> [64, 8, 676] + [8, 1]
conv.shape, conv.nelement()

In [None]:
# recover the conv to feature map (B, C', H', W')
conv = conv.view(conv.shape[0], 8, 26, 26)
conv.shape

In [None]:
conv = F.relu(conv)
conv = F.max_pool2d(conv, kernel_size=2)
conv.shape

In [None]:
conv = conv.view(conv.shape[0], -1)
conv.shape

In [None]:
w_out = torch.randn(8*13*13, 10)
logits = conv @ w_out
logits.shape

In [None]:
class ScratchCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(8, 1, 3, 3) * 0.1)
        self.bias = torch.nn.Parameter(torch.zeros(8))
        self.fc = torch.nn.Linear(8 * 13 * 13, 10)

    def forward(self, x):
        patches = F.unfold(x, kernel_size=3, stride=1)
        w_flat = self.weight.view(8, -1)

        conv = w_flat @ patches + self.bias.unsqueeze(1)

        B = x.size(0)
        conv = conv.view(B, 8, 26, 26)

        conv = F.relu(conv)
        conv = F.max_pool2d(conv, kernel_size=2)

        conv = conv.view(B, -1)
        out = self.fc(conv)
        return out

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ScratchCNN().to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = torch.nn.CrossEntropyLoss() 

In [None]:
model.train()
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
logits.shape, labels.shape
loss = criterion(logits, labels)
loss.item()

In [None]:
optim.zero_grad()
loss.backward()
optim.step()

In [None]:
optim.param_groups[0]['lr']

In [None]:
epochs = 2
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for i, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item() * imgs.size(0)

        # wandb.log({
        #     'train_loss': loss,
        #     'lr': optim.param_groups[0]['lr']
        # })

    print(f"Epoch {epoch+1}/{epochs}  Train Loss: {total_loss/len(train_loader.dataset):.4f}")

In [None]:
model.eval()
correct = 0

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        pred = logits.argmax(dim=1)
        correct += (pred == labels).sum().item()

        # acc = correct / len(test_loader.dataset)
        # wandb.log({
        #     'val_acc': acc,
        # })

acc = correct / len(test_loader.dataset)
print(f"Test Accuracy: {acc*100:.2f}%")

In [None]:
imgs, labels = next(iter(test_loader))
imgs = imgs[:5]
labels = labels[:5]

In [None]:
plt.imshow(imgs[0][0])

In [None]:
plt.imshow(imgs[1][0])

In [None]:
plt.imshow(imgs[2][0])

In [None]:
plt.imshow(imgs[3][0])

In [None]:
plt.imshow(imgs[4][0])

In [None]:
imgs[0].unsqueeze(0).shape

In [None]:
with torch.no_grad():
    logits = model(imgs)
    
pred = F.softmax(logits, dim=1)
plt.figure(figsize=(8,2))
plt.imshow(pred, cmap='grey')

In [None]:
import wandb
wandb.login()
wandb.init(project='scratch-cnn', name='exp2')

let's use torch.nn.conv

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [4]:
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_ds = datasets.MNIST(root='./data', train=True,  download=False, transform=transform)
test_ds  = datasets.MNIST(root='./data', train=False, download=False, transform=transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

In [5]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(8 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv(x) # [B,8,28,28]
        x = F.relu(x)
        x = F.max_pool2d(x, 2) # [B,8,14,14]
        x = x.view(x.size(0), -1) # flatten
        x = self.fc(x) # [B,10]
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [6]:
epochs = 2
for epoch in range(epochs):
    model.train()
    running_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    avg_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")

Epoch 1/2, Train Loss: 0.2698
Epoch 2/2, Train Loss: 0.1081


In [7]:
model.eval()
correct = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        pred = logits.argmax(dim=1)
        correct += (pred == labels).sum().item()
acc = correct / len(test_loader.dataset)
print(f"Test Accuracy: {acc*100:.2f}%")

Test Accuracy: 97.19%


now we try with 3d conv and 2+1d conv

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [1]:
batch_size = 8
clip_len = 16
height = 112
width = 112

In [4]:
random_data = torch.randn(batch_size, 3, clip_len, height, width)
random_data.shape # B, C, T, H, W

torch.Size([8, 3, 16, 112, 112])

In [47]:
with torch.no_grad():
    x = torch.randn(16, 512)
    bn = nn.BatchNorm1d(512)
    out = bn(x)
    print(out[:,0].mean(), out[:,0].std())

    x = torch.randn(16, 512)
    ln = nn.LayerNorm(512)
    out = ln(x)
    print(out[0,:].mean(), out[0,:].std())

    x = torch.randn(16, 3, 32, 28) # 4D
    bn = nn.BatchNorm2d(3)
    out = bn(x)
    print(out[:,0].mean(), out[:,0].std())

    x = torch.randn(16, 3, 512, 32, 28) # 5D
    bn = nn.BatchNorm3d(3)
    out = bn(x)
    print(out[:,0].mean(), out[:,0].std())

tensor(-2.9802e-08) tensor(1.0328)
tensor(-1.6764e-08) tensor(1.0010)
tensor(1.2772e-08) tensor(1.0000)
tensor(2.6609e-10) tensor(1.0000)


In [48]:
conv1 = nn.Conv3d(in_channels=3, out_channels=16, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1))
bn1 = nn.BatchNorm3d(16)

In [49]:
x = random_data
x = conv1(x)
x = bn1(x)
x.shape

torch.Size([8, 16, 16, 112, 112])

In [None]:
pool1 = nn.MaxPool3d(kernel_size=(2,2,2), stride=(2,2,2))
x = pool1(x)
x.shape

torch.Size([8, 16, 8, 56, 56])

In [51]:
conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1))
bn2 = nn.BatchNorm3d(32)

x = conv2(x)
x = bn2(x)
x.shape

torch.Size([8, 32, 8, 56, 56])

In [52]:
pool2 = nn.MaxPool3d(kernel_size=(2,2,2), stride=(2,2,2))
x = pool2(x)
x.shape

torch.Size([8, 32, 4, 28, 28])

In [None]:
global_pool = nn.AdaptiveAvgPool3d((1,1,1))
x = global_pool(x)
x.shape

torch.Size([8, 32, 1, 1, 1])

In [54]:
x = x.view(x.size(0), -1)
x.shape

torch.Size([8, 32])

In [57]:
fc = nn.Linear(32, 10)
x = fc(x)
x.shape

torch.Size([8, 10])

In [58]:
x = torch.randn(16, 3, 16, 112, 112)
spatial_conv = nn.Conv3d(3, 16, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
bn1 = nn.BatchNorm3d(16)
temp_conv = nn.Conv3d(16, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
bn2 = nn.BatchNorm3d(16)

In [62]:
r2plus1d = nn.Sequential(spatial_conv, bn1, temp_conv, bn2)
r2plus1d

Sequential(
  (0): Conv3d(3, 16, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
  (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): Conv3d(16, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
  (3): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [63]:
r2plus1d(x).shape

torch.Size([16, 16, 16, 112, 112])