In [1]:
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [2]:
BATCH_SIZE = 10

# FashionMNIST training split, downloaded into F_MNIST_data/ on first run;
# ToTensor() converts each image to a tensor.
trainData = datasets.FashionMNIST("F_MNIST_data", train=True, download=True, transform=ToTensor())

# Pinned (page-locked) host memory only helps when batches are copied to a CUDA
# device; on a CPU-only machine it is useless and DataLoader warns about it.
trainLoader = DataLoader(
    dataset=trainData,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to F_MNIST_data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 17995245.33it/s]


Extracting F_MNIST_data/FashionMNIST/raw/train-images-idx3-ubyte.gz to F_MNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to F_MNIST_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 262441.32it/s]


Extracting F_MNIST_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to F_MNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to F_MNIST_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 4963034.76it/s]


Extracting F_MNIST_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to F_MNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to F_MNIST_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 9863991.32it/s]

Extracting F_MNIST_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to F_MNIST_data/FashionMNIST/raw






In [3]:
class FashionMNIST_Part1(nn.Module):
    """First pipeline stage: flatten the input, then a stack of Linear+ReLU layers.

    Args:
        in_features: number of features per sample after flattening
            (default 28*28, one FashionMNIST image).
        hidden_sizes: output width of each Linear layer, in order; every
            Linear is followed by a ReLU. The last width is this stage's
            output size and must match the next stage's input size.

    The defaults reproduce the original 784 -> 20000 -> 10000 -> 5000 stack,
    and layers are registered in the same order inside ``self.layers`` so
    state_dict keys (``layers.1.weight``, ...) are unchanged.
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(20000, 10000, 5000)):
        super().__init__()
        modules = [nn.Flatten()]
        prev_width = in_features
        for width in hidden_sizes:
            modules += [nn.Linear(prev_width, width), nn.ReLU()]
            prev_width = width
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run the stage; returns the ReLU activations of the last Linear layer."""
        return self.layers(x)

class FashionMNIST_Part2(nn.Module):
    """Second pipeline stage: Linear+ReLU hidden layers, then a Linear classifier head.

    Args:
        in_features: width of the incoming activations (default 5000, the
            output width of the default first stage).
        hidden_sizes: output width of each hidden Linear layer; each is
            followed by a ReLU.
        num_classes: number of output logits (default 10 FashionMNIST classes).

    The head has no activation, so the output is raw logits (what
    ``nn.CrossEntropyLoss`` expects). The defaults reproduce the original
    5000 -> 1000 -> 500 -> 10 stack with identical ``self.layers`` indices.
    """

    def __init__(self, in_features=5000, hidden_sizes=(1000, 500), num_classes=10):
        super().__init__()
        modules = []
        prev_width = in_features
        for width in hidden_sizes:
            modules += [nn.Linear(prev_width, width), nn.ReLU()]
            prev_width = width
        modules.append(nn.Linear(prev_width, num_classes))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run the stage; returns un-normalized class logits."""
        return self.layers(x)


In [4]:
# Number of CUDA devices visible to this kernel (0 when CUDA is unavailable).
# The two model stages below are meant to sit on separate GPUs, so 2 is the intended value.
torch.cuda.device_count()

2

In [5]:
SEED = 0
# Seed the global RNG so weight initialization (and the DataLoader's shuffle
# order, which by default is drawn from the same global generator) repeats on
# Restart & Run All. GPU kernels may still introduce small nondeterminism.
torch.manual_seed(SEED)

# Model-parallel placement: stage 1 on the first GPU, stage 2 on the second.
# Choosing by device *count* (not just is_available()) avoids asking for a
# non-existent "cuda:1" on a single-GPU machine: both stages then share cuda:0,
# and with no GPU everything runs on the CPU.
num_gpus = torch.cuda.device_count()
device1 = torch.device("cuda:0") if num_gpus >= 1 else torch.device("cpu")
device2 = torch.device("cuda:1") if num_gpus >= 2 else device1

model_part1 = FashionMNIST_Part1().to(device1)
model_part2 = FashionMNIST_Part2().to(device2)

criterion = nn.CrossEntropyLoss()
# A single optimizer over both stages' parameters, so one step() updates the whole model.
optimizer = torch.optim.SGD(
    list(model_part1.parameters()) + list(model_part2.parameters()),
    lr=0.01,
    momentum=0.5,
)

In [6]:
epochs = 10
num_samples = len(trainLoader.dataset)

# Explicitly enter training mode so re-running this cell after an evaluation
# cell (which may have called .eval()) still trains correctly.
model_part1.train()
model_part2.train()

for epoch in range(epochs):
    running_loss = 0.0
    running_corrects = 0  # plain Python int; accumulated via .item() below
    for inputs, labels in trainLoader:
        # non_blocking lets the copy from pinned host memory overlap with compute;
        # it is a no-op when the target is the CPU.
        inputs = inputs.to(device1, non_blocking=True)
        labels = labels.to(device2, non_blocking=True)

        optimizer.zero_grad()

        # Pipeline hand-off: stage 1 runs on device1, and its activations are moved
        # to device2 for stage 2. The .to() is recorded by autograd, so
        # gradients flow back across devices in loss.backward().
        intermediates = model_part1(inputs).to(device2)
        outputs = model_part2(intermediates)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        # criterion returns the batch-mean loss; weight by batch size to get a per-sample average.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    epoch_loss = running_loss / num_samples
    epoch_acc = running_corrects / num_samples
    print('Epoch [{}/{}], Loss: {:.4f}, Acc: {:.4f}'.format(epoch+1, epochs, epoch_loss, epoch_acc))

Epoch [1/10], Loss: 0.6561, Acc: 0.7601
Epoch [2/10], Loss: 0.3683, Acc: 0.8662
Epoch [3/10], Loss: 0.3142, Acc: 0.8822
Epoch [4/10], Loss: 0.2788, Acc: 0.8966
Epoch [5/10], Loss: 0.2502, Acc: 0.9063
Epoch [6/10], Loss: 0.2276, Acc: 0.9144
Epoch [7/10], Loss: 0.2086, Acc: 0.9206
Epoch [8/10], Loss: 0.1913, Acc: 0.9285
Epoch [9/10], Loss: 0.1749, Acc: 0.9343
Epoch [10/10], Loss: 0.1612, Acc: 0.9394
