In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor 
import torchvision.transforms as T
import torch.nn as nn

import numpy as np

In [2]:
traindt = datasets.MNIST(
    root = 'data',
    train = True,                         
    transform = ToTensor(), 
    download = True,            
)

In [3]:
transform = T.Compose([T.Resize(10), T.ToTensor(),])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)

class MulticlassLogisticRegression(nn.Module):
    def __init__(self):
        super(MulticlassLogisticRegression, self).__init__()
        self.linear = nn.Linear(100, 10, bias=False) 

    def forward(self, x):
        x = x.view(-1, 10*10)
        x = self.linear(x)
        return x

model = MulticlassLogisticRegression()
print(model)

MulticlassLogisticRegression(
  (linear): Linear(in_features=100, out_features=10, bias=False)
)


In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


for epoch in range(50):
    sum_losses = 0
    for x, (images, labels) in enumerate(train_loader): 
        images = images.reshape(-1, 10*10)
        
        out = model(images)
        losses = criterion(out, labels)
        optimizer.zero_grad()
        losses.backward()
        optimizer.step() 
        sum_losses += losses.item()
        if x == 10:
            break
        
    print(f"EPOCH {epoch+1} TRAIN LOSS : {sum_losses / len(train_loader)}")

EPOCH 1 TRAIN LOSS : 0.049429771234231715
EPOCH 2 TRAIN LOSS : 0.040949943985766186
EPOCH 3 TRAIN LOSS : 0.03474326631916103
EPOCH 4 TRAIN LOSS : 0.03010570342098472
EPOCH 5 TRAIN LOSS : 0.025859049388340542
EPOCH 6 TRAIN LOSS : 0.023904474432280325
EPOCH 7 TRAIN LOSS : 0.02236933634479417
EPOCH 8 TRAIN LOSS : 0.020130833710180415
EPOCH 9 TRAIN LOSS : 0.01887476634877577
EPOCH 10 TRAIN LOSS : 0.018072802374866217
EPOCH 11 TRAIN LOSS : 0.01689650064338245
EPOCH 12 TRAIN LOSS : 0.016210855832740442
EPOCH 13 TRAIN LOSS : 0.015215442378892064
EPOCH 14 TRAIN LOSS : 0.014970771031084854
EPOCH 15 TRAIN LOSS : 0.01565117495400565
EPOCH 16 TRAIN LOSS : 0.013801419404523967
EPOCH 17 TRAIN LOSS : 0.013834411592117504
EPOCH 18 TRAIN LOSS : 0.013720992277425998
EPOCH 19 TRAIN LOSS : 0.013430981303074721
EPOCH 20 TRAIN LOSS : 0.012443502828764764
EPOCH 21 TRAIN LOSS : 0.013017297998420212
EPOCH 22 TRAIN LOSS : 0.012166319053564498
EPOCH 23 TRAIN LOSS : 0.012567679765127869
EPOCH 24 TRAIN LOSS : 0.01

In [6]:
model.eval()

weights = model.linear.weight.detach().cpu().numpy()

In [6]:
weights.shape

(10, 100)

In [7]:
for i, (imgs, labels) in enumerate(train_loader):
    
    outputs = nn.Softmax(dim=1)(model(imgs))
    imgs = imgs.detach().cpu().numpy()
    
    imgs2 = imgs.reshape((128, 100))
    
    w = (imgs2[:, np.newaxis, :] * weights[np.newaxis, :, :]).transpose(2, 1, 0)
    w = np.maximum(w, 0)
    w = np.minimum(w, 1)
    outputs = outputs.detach().cpu().numpy()
    imgs = np.squeeze(imgs, axis=1).transpose(1, 2, 0)
    outputs = np.expand_dims(outputs, axis=1).transpose(1, 2, 0)
    np.save("inputs.npy", imgs)
    np.save("outputs.npy", outputs)
    np.save("weights.npy", w)
    break