In [21]:
import torch
import torch.nn as nn

In [6]:
# Batched matrix product: a (10, 3) input times a (3, 5) weight yields (10, 5).
x = torch.randn((10, 3))
W = torch.randn((3, 5))
y = torch.matmul(x, W)
print(y.size())

torch.Size([10, 5])


In [8]:
# Apply a fully connected layer (3 -> 5 features) to a batch of 128 samples.
# The batch is called `inputs` rather than `input` so that the builtin
# `input()` stays usable for the rest of the kernel session.
inputs = torch.randn(128, 3)
linear1 = nn.Linear(3, 5, bias=True, device='cpu')
output = linear1(inputs)
print(output.shape)

torch.Size([128, 5])


In [16]:
# Inspect the layer's weight, move the layer to an accelerator when one is
# available, and make a bfloat16 copy of the weight.
print(linear1.weight.dtype, linear1.weight.device)

# Fall back to the CPU so this cell also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
linear1.to(device)  # nn.Module.to() moves the module's parameters in place

# Tensor.bfloat16() returns a converted tensor; linear1.weight keeps its dtype.
w3 = linear1.weight.bfloat16()
print(w3.dtype, linear1.weight.device)

torch.float32 cuda:0
torch.bfloat16 cuda:0


In [17]:
# 3x3 convolution with stride 1 and padding 1: channels go 3 -> 5 while the
# 32x32 spatial size is preserved.
# Build the layer once and keep a handle to it so its parameters stay
# inspectable after the forward pass.
conv = nn.Conv2d(3, 5, kernel_size=3, stride=1, padding=1)
img = torch.randn(128, 3, 32, 32)
o = conv(img)
print(o.shape)

torch.Size([128, 5, 32, 32])


In [18]:
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision import transforms

# FashionMNIST train/test splits, stored under ./data (downloaded if missing)
# and converted to tensors. The training split is constructed first.
train, test = (
    datasets.FashionMNIST(
        root='data',
        train=is_train_split,
        download=True,
        transform=transforms.ToTensor(),
    )
    for is_train_split in (True, False)
)

# Mini-batches of 32; only the training loader is shuffled.
train_loader = DataLoader(dataset=train, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test, batch_size=32, shuffle=False)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:09<00:00, 2758967.13it/s]


Extracting data\FashionMNIST\raw\train-images-idx3-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 108348.31it/s]


Extracting data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:03<00:00, 1135028.76it/s]


Extracting data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 4857655.12it/s]

Extracting data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw






In [19]:
# Pull a single mini-batch from the training loader and show its tensor shapes.
train_iter = iter(train_loader)
images, labels = next(train_iter)
(images.shape, labels.shape)

(torch.Size([32, 1, 28, 28]), torch.Size([32]))

In [23]:
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 convolutions followed by three linear layers.

    Both convolutions are unpadded and each is followed by ReLU and 2x2 max
    pooling, so a 28x28 input shrinks 28 -> 24 -> 12 -> 8 -> 4 and reaches the
    classifier as 16 feature maps of size 4x4. Inputs with another spatial
    size do not match ``fc1`` and raise a shape error.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of scores produced per sample.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute unnormalized class scores.

        Args:
            x: Images of shape (batch, in_channels, 28, 28). A single
                (in_channels, 28, 28) image is treated as a batch of one.

        Returns:
            Tensor of shape (batch, num_classes); no softmax is applied.

        Raises:
            RuntimeError: If the flattened feature size does not match ``fc1``.
        """
        if x.dim() == 3:
            # Add the batch axis so the per-sample flatten below stays valid.
            x = x.unsqueeze(0)
        x = torch.max_pool2d(torch.relu(self.conv1(x)), 2)
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
        # Flatten everything except the batch axis. A view(-1, 256) would fold
        # wrongly sized feature maps into extra rows and silently change the
        # batch size; this way the mismatch surfaces as an error in fc1.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
    
# Instantiate the network and push the `images` batch through it to confirm
# that one row of class scores comes out per image.
lenet = LeNet()
output = lenet(images)
print(output.size())

torch.Size([32, 10])


In [24]:
# Train `lenet` with Adam on cross-entropy loss, then report test accuracy.
NUM_EPOCHS = 10
LOG_EVERY = 100  # print the current batch loss once every this many batches

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lenet.parameters(), lr=0.001)

# Explicit training mode: a no-op for the current layers, but keeps the loop
# correct if dropout or batch-norm layers are added to the model later.
lenet.train()
# NOTE: both loops rebind the notebook-level `images` / `labels` names.
for epoch in range(NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()  # gradients accumulate across backward() calls
        outputs = lenet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if i % LOG_EVERY == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')

# Switch to inference mode before measuring accuracy.
lenet.eval()
correct = 0
total = 0
with torch.no_grad():  # no autograd bookkeeping needed for evaluation
    for images, labels in test_loader:
        outputs = lenet(images)
        # Predicted class = index of the largest score in each row. Taking
        # `.indices` avoids assigning to `_`, which IPython uses for the last
        # cell output.
        predicted = torch.max(outputs, 1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}')

Epoch 0, Loss: 2.3150200843811035
Epoch 0, Loss: 1.0898656845092773
Epoch 0, Loss: 0.8683124780654907
Epoch 0, Loss: 0.8185096383094788
Epoch 0, Loss: 0.6437197327613831
Epoch 0, Loss: 0.779932975769043
Epoch 0, Loss: 0.48604175448417664
Epoch 0, Loss: 0.7252881526947021
Epoch 0, Loss: 0.4339439570903778
Epoch 0, Loss: 0.35256826877593994
Epoch 0, Loss: 0.49526628851890564
Epoch 0, Loss: 0.3143845200538635
Epoch 0, Loss: 0.42572683095932007
Epoch 0, Loss: 0.3852396011352539
Epoch 0, Loss: 0.5696452856063843
Epoch 0, Loss: 0.4369112551212311
Epoch 0, Loss: 0.7685883641242981
Epoch 0, Loss: 0.31939011812210083
Epoch 0, Loss: 0.7180009484291077
Epoch 1, Loss: 0.5560163259506226
Epoch 1, Loss: 0.4160979688167572
Epoch 1, Loss: 0.5232569575309753
Epoch 1, Loss: 0.5342051982879639
Epoch 1, Loss: 0.28193604946136475
Epoch 1, Loss: 0.3335415720939636
Epoch 1, Loss: 0.3790576159954071
Epoch 1, Loss: 0.4737567603588104
Epoch 1, Loss: 0.3045908808708191
Epoch 1, Loss: 0.35865819454193115
Epoch 1,