In [1]:
import torch
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt
import random

In [16]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((572,572)),
    torchvision.transforms.ToTensor()
])
raw_training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transform
)

raw_test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transform
)
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}


In [17]:
class ContractBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2)
        )

    def forward(self, x):
        return self.block(x)

class ExpandBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.ReLU()
        )

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        return self.block(x)

class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = ContractBlock(1, 64)
        self.enc2 = ContractBlock(64, 128)
        self.enc3 = ContractBlock(128, 256)
        self.enc4 = ContractBlock(256, 512)

        self.middle = torch.nn.Sequential(
            torch.nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            torch.nn.ReLU()
        )

        self.dec4 = ExpandBlock(1024, 512)
        self.dec3 = ExpandBlock(512, 256)
        self.dec2 = ExpandBlock(256, 128)
        self.dec1 = ExpandBlock(128, 64)

        self.final = torch.nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(s1)
        s3 = self.enc3(s2)
        s4 = self.enc4(s3)

        b = self.middle(s4)

        d4 = self.dec4(b, s4)
        d3 = self.dec3(d4, s3)
        d2 = self.dec2(d3, s2)
        d1 = self.dec1(d2, s1)

        return self.final(d1)

# Example instantiation
net = UNet()
net.parameters()

<generator object Module.parameters at 0x0000013645DBCBA0>

In [None]:
dataloader = torch.utils.data.DataLoader(raw_training_data)
optimizer = torch.optim.Adam(net.parameters())
loss_fn = torch.nn.Softmax()
for epoch in range(200):
    for img, label in dataloader:
        optimizer.zero_grad()
        pred = net(img)
        loss=  loss_fn(pred,5)
        #just realized unet is a binary classifier