In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

# Allow up to 120 characters per line when tensors are printed.
torch.set_printoptions(linewidth=120)
# Explicitly switch gradient tracking on. Because this is the last expression
# in the cell, the returned set_grad_enabled object is echoed as cell output.
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x11a002c88>

In [3]:
def get_num_correct(preds,labels):
    """Count the rows of ``preds`` whose arg-max index matches ``labels``.

    Args:
        preds: tensor of per-class scores; the arg-max is taken along dim 1.
        labels: tensor of target class indices, compared element-wise with
            the arg-max result.

    Returns:
        The number of matching positions as a plain Python number.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes.eq(labels)
    return matches.sum().item()

In [4]:
# Fashion-MNIST training split. It is downloaded into ./data/FashionMNIST when
# not already present, and every image is converted to a tensor on access.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [5]:
class Network(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The layer sizes are fixed for single-channel 28x28 inputs: each 5x5
    convolution (no padding) followed by a 2x2 max-pool takes the spatial size
    28 -> 24 -> 12 -> 8 -> 4, leaving 12 feature maps of 4x4 that feed ``fc1``.
    The output is a tensor of 10 raw (un-normalised) class scores per sample.
    """

    # Shape (channels, height, width) the conv stack must produce per sample
    # for the flattened features to line up with ``fc1``.
    _FEATURE_SHAPE = (12, 4, 4)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Run the network on ``t`` and return the class scores.

        Args:
            t: image tensor whose trailing dimensions are (1, 28, 28).

        Returns:
            Tensor of shape (rows, 10) with one row of scores per sample.

        Raises:
            RuntimeError: if the convolutional features are not 12x4x4 per
                sample, i.e. the input images were not 28x28.
        """
        # Stage 1: convolution -> ReLU -> 2x2 max-pool
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Stage 2: convolution -> ReLU -> 2x2 max-pool
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # reshape(-1, 192) would happily regroup a wrongly sized batch into a
        # different number of rows (mixing features of different samples), so
        # verify the per-sample shape first. RuntimeError matches what torch
        # itself raises for shape mismatches.
        if tuple(t.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                "Network expects 1x28x28 inputs; convolutional features have shape "
                f"{tuple(t.shape[-3:])} instead of {self._FEATURE_SHAPE}"
            )
        t = t.reshape(-1, 12*4*4)

        t = self.fc1(t)
        t = F.relu(t)

        t = self.fc2(t)
        t = F.relu(t)

        # No softmax here: the raw scores are returned as-is.
        t = self.out(t)

        return t

In [8]:
# Demonstrate a single optimisation step on one batch and compare the loss
# before and after the weight update.
network = Network()

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)  # the optimizer is handed the network's weights

batch = next(iter(train_loader))  # first batch only
images, labels = batch

preds = network(images)  # forward pass
loss = F.cross_entropy(preds, labels)  # loss before the update

loss.backward()  # compute gradients
optimizer.step()  # update the weights

print('loss1:', loss.item())
# Re-evaluate the same batch with the updated weights.
preds = network(images)
loss = F.cross_entropy(preds, labels)
print('loss2:', loss.item())  # previously mislabelled as 'loss1:'

loss1: 2.3079214096069336
loss1: 2.2951629161834717


In [9]:
# Train a fresh network for one full pass over the training set, tallying the
# summed batch losses and the number of correct predictions along the way.
network = Network()
optimizer = optim.Adam(network.parameters(), lr=0.01)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)

total_loss, total_correct = 0, 0

for batch in train_loader:
    images, labels = batch

    # Gradients accumulate across backward() calls, so clear them for each batch.
    optimizer.zero_grad()

    preds = network(images)
    loss = F.cross_entropy(preds, labels)
    loss.backward()
    optimizer.step()

    # Both tallies use the predictions made before this batch's weight update.
    total_correct += get_num_correct(preds, labels)
    total_loss += loss.item()

print("epoch:", 0, "total_correct:", total_correct, "loss:", total_loss)


epoch: 0 total_correct: 47212 loss: 334.7672748565674


In [10]:
# Share of the training set counted as correct in total_correct.
total_correct / len(train_set)

0.7868666666666667

In [6]:
# Train a fresh network for five passes over the training set; the loss and
# correct-prediction tallies restart at every epoch and are printed per epoch.
network = Network()
optimizer = optim.Adam(network.parameters(), lr=0.01)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)

for epoch in range(5):
    total_loss, total_correct = 0, 0

    for batch in train_loader:
        images, labels = batch

        # Gradients accumulate across backward() calls, so clear them for each batch.
        optimizer.zero_grad()

        preds = network(images)
        loss = F.cross_entropy(preds, labels)
        loss.backward()
        optimizer.step()

        # Both tallies use the predictions made before this batch's weight update.
        total_correct += get_num_correct(preds, labels)
        total_loss += loss.item()

    print("epoch:", epoch, "total_correct:", total_correct, "loss:", total_loss)

epoch: 0 total_correct: 46180 loss: 364.37743493914604
epoch: 1 total_correct: 50833 loss: 246.43286822736263
epoch: 2 total_correct: 51676 loss: 223.522591650486
epoch: 3 total_correct: 52130 loss: 212.12095564603806
epoch: 4 total_correct: 52296 loss: 206.46559286117554


In [7]:
# Share of the training set counted as correct. total_correct is reset at the
# start of every epoch, so this reflects the final epoch only.
total_correct / len(train_set)

0.8716

In [14]:
# Print every named parameter of the network together with its full tensor
# value. The output is long; printing param.shape instead gives a compact view.
for name, param in network.named_parameters():
    print(name, '\t', param)

e-01],
          [-4.7825e-01, -3.8159e-01, -4.4542e-01, -2.5959e-01, -4.0990e-01],
          [-3.9337e-01, -4.0768e-01, -3.9919e-01, -3.0776e-01, -3.7316e-01],
          [-5.1077e-01, -4.6276e-01, -3.7039e-01, -2.3538e-01, -3.3991e-01],
          [-3.5820e-01, -3.6214e-01, -4.8788e-01, -3.4681e-01, -4.3302e-01]],

         [[ 1.5422e-01, -4.1945e-01, -3.6128e-01, -6.0054e-02, -1.6455e-02],
          [ 1.3351e-01, -4.8144e-01, -1.6687e-01, -1.4987e-01,  1.3087e-01],
          [-2.0308e-01, -3.0870e-01, -5.2527e-01, -1.3949e-01,  1.6658e-01],
          [-1.6298e-01, -3.3641e-01, -2.0040e-01, -4.5271e-02,  1.3012e-01],
          [-5.0621e-01, -2.6261e-01, -2.8340e-01, -2.6641e-02,  3.2039e-01]],

         [[-4.2162e-01, -2.4276e-01, -5.2980e-01, -4.0144e-01, -2.6757e-01],
          [-7.9771e-02, -1.4449e-01, -2.5188e-01, -2.7640e-01, -3.0085e-01],
          [-1.1547e-01, -1.4793e-01, -5.3237e-01, -3.4014e-01, -1.1607e-01],
          [-1.2901e-01, -4.1581e-01, -2.3483e-01, -3.3300e-01, -3