In [1]:
# Import pacakges

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms
import itertools
import matplotlib.pyplot as plt
from functools import partial

In [2]:
# Set the random seeds
torch.manual_seed(0)
np.random.seed(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
print("device:", device)

device: cuda:4


In [12]:
BATCH_SIZE = 100 # Batch size
# Load the dataset
train_set = datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                       ]))
train_loader = torch.utils.data.DataLoader(train_set, BATCH_SIZE)

test_set = datasets.MNIST('./data', train=False, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ]))

test_loader = torch.utils.data.DataLoader(test_set, BATCH_SIZE)

In [24]:
for data in train_loader:
    print(data[0].shape)
    print(data[1].shape)
    break

torch.Size([100, 1, 28, 28])
torch.Size([100])


In [10]:
class SimpleClassification(nn.Module):

    def __init__(self, input_dim = 28*28, out_dim = 10):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, out_dim, bias=False)


    def forward(self,x):
        return self.linear1(x)

In [15]:
test = SimpleClassification()
sum(param.numel() for param in test.parameters())

7840

In [14]:
model = SimpleClassification()

AE_EPOCHS = 10 # Epochs for training the autoencoder
# We use a binary cross-entropy loss for the reconstruction error
loss_f = nn.CrossEntropyLoss()

# Build the autoencoder
model.to(device)

optimizer = torch.optim.SGD(itertools.chain(model.parameters()),
                             lr=1e-3)

for i in range(AE_EPOCHS):
    print('Epoch #{}'.format(i+1))

    losses = []
    for batch_idx, data in enumerate(train_loader):

        x, y = data
        x = x.to(device) # [B, 1, H, W]
        y = y.to(device)

        x = x.squeeze(1).reshape(x.size(0), -1)

        # Run the autoencoder
        out = model(x)
        loss = loss_f(out, y)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 200 == 0:
            print(f"E {i}, {batch_idx:.3f}/{len(train_loader)}: Loss: {loss.item()}")

            ## Evaluate
            total_acc = 0
            for data in test_loader:
                x, y = data
                x, y = x.to(device), y.to(device)
                x = x.squeeze(1).reshape(x.size(0), -1)
                out = model(x)
                labels = torch.argmax(out, dim = 1) # [B]
                acc = sum(labels == y) / labels.numel()
                total_acc +=acc
            print(f"evaluation acc: {total_acc / len(test_loader)}")

Epoch #1
E 0, 0.000/600: Loss: 2.3616983890533447
evaluation acc: 0.09650001674890518
E 0, 200.000/600: Loss: 2.1203904151916504
evaluation acc: 0.4345000088214874
E 0, 400.000/600: Loss: 1.9179693460464478
evaluation acc: 0.642300009727478
Epoch #2
E 1, 0.000/600: Loss: 1.7856677770614624
evaluation acc: 0.7069999575614929
E 1, 200.000/600: Loss: 1.735583782196045
evaluation acc: 0.7375999093055725
E 1, 400.000/600: Loss: 1.519203782081604
evaluation acc: 0.7594999670982361
Epoch #3
E 2, 0.000/600: Loss: 1.4406170845031738
evaluation acc: 0.7713997960090637
E 2, 200.000/600: Loss: 1.4831302165985107
evaluation acc: 0.7834998369216919
E 2, 400.000/600: Loss: 1.2670003175735474
evaluation acc: 0.7970998287200928
Epoch #4
E 3, 0.000/600: Loss: 1.2196794748306274
evaluation acc: 0.8043997287750244
E 3, 200.000/600: Loss: 1.313604474067688
evaluation acc: 0.80899977684021
E 3, 400.000/600: Loss: 1.100113034248352
evaluation acc: 0.8158999681472778
Epoch #5
E 4, 0.000/600: Loss: 1.070811033

In [34]:
import torch.nn.functional as F
class SimpleScratch():

    def __init__(self, input_dim = 28*28, output_dim = 10):
        self.W = torch.randn([input_dim, output_dim])
        self.n_class = output_dim

        pass
    
    
    def _get_grad(self, x, y):
        ## x: [B, M]
        ## y: [B]
        y_one_hot = F.one_hot(y, self.n_class).float()# [B, K]

        out = self.predict(x) # [B, K]
        soft = F.softmax(out, dim = -1) # [B,K]

        grad_out = -y_one_hot + soft # [B, K]
        grad_W = torch.bmm(x.unsqueeze(-1), grad_out.unsqueeze(1)) #[B,M,K]
        grad_W = grad_W.mean(dim =0)
        return grad_W
    
    def train(self, x, y, lr=0.001):
        grad_W = self._get_grad(x, y)
        self.W = self.W - lr * grad_W


    def predict(self, x:torch.Tensor):
        ## x: [B, M]
        W = self.W.unsqueeze(0).repeat(x.size(0),1, 1) # [B, M, K]
        out = torch.bmm(W.transpose(1,2), x.unsqueeze(-1)) # [B, K, 1]
        return out.squeeze(-1) # [B, K]

In [35]:
model = SimpleScratch()
for data in train_loader:
    x = data[0]
    y = data[1]
    break
x = x.squeeze(1).reshape(x.size(0),-1)
print("x", x.shape)
print("y", y.shape)
grad = model._get_grad(x, y)
print(grad.shape)

x torch.Size([100, 784])
y torch.Size([100])
torch.Size([784, 10])


In [37]:
model = SimpleScratch()

LR = 1.0e-2
AE_EPOCHS = 10 # Epochs for training the autoencoder
# We use a binary cross-entropy loss for the reconstruction error
for i in range(AE_EPOCHS):
    print('Epoch #{}'.format(i+1))

    losses = []
    for batch_idx, data in enumerate(train_loader):

        x, y = data

        x = x.squeeze(1).reshape(x.size(0), -1)

        # Run the autoencoder
        model.train(x, y, lr = LR)

        if batch_idx % 200 == 0:
            print(f"E {i}, {batch_idx:.3f}/{len(train_loader)}: Loss: {loss.item()}")

            ## Evaluate
            total_acc = 0
            for data in test_loader:
                x, y = data
                x = x.squeeze(1).reshape(x.size(0), -1)
                out = model.predict(x)
                labels = torch.argmax(out, dim = 1) # [B]
                acc = sum(labels == y) / labels.numel()
                total_acc +=acc
            print(f"evaluation acc: {total_acc / len(test_loader)}")

Epoch #1
E 0, 0.000/600: Loss: 0.8701011538505554
evaluation acc: 0.11789999902248383
E 0, 200.000/600: Loss: 0.8701011538505554
evaluation acc: 0.20009993016719818
E 0, 400.000/600: Loss: 0.8701011538505554
evaluation acc: 0.3082999587059021
Epoch #2
E 1, 0.000/600: Loss: 0.8701011538505554
evaluation acc: 0.38159993290901184
E 1, 200.000/600: Loss: 0.8701011538505554
evaluation acc: 0.43469998240470886
E 1, 400.000/600: Loss: 0.8701011538505554
evaluation acc: 0.48019999265670776
Epoch #3
E 2, 0.000/600: Loss: 0.8701011538505554
evaluation acc: 0.5152999758720398
E 2, 200.000/600: Loss: 0.8701011538505554
evaluation acc: 0.544499933719635
E 2, 400.000/600: Loss: 0.8701011538505554
evaluation acc: 0.5703999400138855
Epoch #4
E 3, 0.000/600: Loss: 0.8701011538505554
evaluation acc: 0.5906999111175537
E 3, 200.000/600: Loss: 0.8701011538505554
evaluation acc: 0.6130999326705933
E 3, 400.000/600: Loss: 0.8701011538505554
evaluation acc: 0.630099892616272
Epoch #5
E 4, 0.000/600: Loss: 0.

In [22]:
x = torch.ones(1, 2, 3)
x[0][0] = 20
out = x.repeat(5, 1, 1)
print(out)

tensor([[[20., 20., 20.],
         [ 1.,  1.,  1.]],

        [[20., 20., 20.],
         [ 1.,  1.,  1.]],

        [[20., 20., 20.],
         [ 1.,  1.,  1.]],

        [[20., 20., 20.],
         [ 1.,  1.,  1.]],

        [[20., 20., 20.],
         [ 1.,  1.,  1.]]])
