In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor 
import torchvision.transforms as T
import torch.nn as nn

import numpy as np

In [2]:
traindt = datasets.MNIST(
    root = 'data',
    train = True,                         
    transform = ToTensor(), 
    download = True,            
)

In [3]:
transform = T.Compose([T.Resize(10), T.ToTensor(),])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)

class MulticlassLogisticRegression(nn.Module):
    def __init__(self):
        super(MulticlassLogisticRegression, self).__init__()
        self.linear = nn.Linear(100, 10, bias=False) 

    def forward(self, x):
        x = x.view(-1, 10*10)
        x = self.linear(x)
        return x

model = MulticlassLogisticRegression()
print(model)

MulticlassLogisticRegression(
  (linear): Linear(in_features=100, out_features=10, bias=False)
)


In [4]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.to(device)
for epoch in range(50):
    sum_losses = 0
    for x, (images, labels) in enumerate(train_loader): 
        images = images.reshape(-1, 10*10)
        images = images.to(device)
        labels = labels.to(device)
        out = model(images)
        losses = criterion(out, labels)
        optimizer.zero_grad()
        losses.backward()
        optimizer.step() 
        sum_losses += losses.item()
        if x == 10:
            break
        
    print(f"EPOCH {epoch+1} TRAIN LOSS : {sum_losses / len(train_loader)}")

EPOCH 1 TRAIN LOSS : 0.04927086296366222
EPOCH 2 TRAIN LOSS : 0.040947441353218386
EPOCH 3 TRAIN LOSS : 0.034803865052489585
EPOCH 4 TRAIN LOSS : 0.029897562476363517
EPOCH 5 TRAIN LOSS : 0.02601870621191159
EPOCH 6 TRAIN LOSS : 0.02355194066379116
EPOCH 7 TRAIN LOSS : 0.021956349740912918
EPOCH 8 TRAIN LOSS : 0.020437889388883548
EPOCH 9 TRAIN LOSS : 0.01926871797423373
EPOCH 10 TRAIN LOSS : 0.018611939858271878
EPOCH 11 TRAIN LOSS : 0.017569879478991413
EPOCH 12 TRAIN LOSS : 0.017486985177119403
EPOCH 13 TRAIN LOSS : 0.015547546496523469
EPOCH 14 TRAIN LOSS : 0.015089636164179234
EPOCH 15 TRAIN LOSS : 0.014797103430416538
EPOCH 16 TRAIN LOSS : 0.014493622759511984
EPOCH 17 TRAIN LOSS : 0.01442087700626235
EPOCH 18 TRAIN LOSS : 0.013671220556251022
EPOCH 19 TRAIN LOSS : 0.013384997209252071
EPOCH 20 TRAIN LOSS : 0.013659777671797697
EPOCH 21 TRAIN LOSS : 0.012935150279673431
EPOCH 22 TRAIN LOSS : 0.012510468202359133
EPOCH 23 TRAIN LOSS : 0.011871099789767886
EPOCH 24 TRAIN LOSS : 0.0

In [5]:
model.eval()

weights = model.linear.weight.detach().cpu().numpy()

In [6]:
weights.shape

(10, 100)

In [16]:
for i, (imgs, labels) in enumerate(train_loader):
    imgs = imgs.to(device)
    outputs = nn.Softmax(dim=1)(model(imgs))
    imgs = imgs.detach().cpu().numpy()
    
    imgs2 = imgs.reshape((128, 100))
    
    w = (imgs2[:, np.newaxis, :] * weights[np.newaxis, :, :]).transpose(2, 1, 0)
    w = np.maximum(w, 0)
    w = np.minimum(w, 1)
    outputs = outputs.detach().cpu().numpy()
    imgs = np.squeeze(imgs, axis=1).transpose(1, 2, 0)
    outputs = np.expand_dims(outputs, axis=1).transpose(1, 2, 0)
    np.save("inputs.npy", imgs)
    np.save("outputs.npy", outputs)
    np.save("weights.npy", w)
    break

(128, 1, 10, 10)
(128, 100)
(100, 10, 128)
