In [None]:
import torch
from torch import nn, optim, Tensor
from datasets import load_dataset, load_dataset_builder
from torch.utils.data import DataLoader, default_collate, Dataset
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Pull the sidewalk segmentation dataset from the Hugging Face Hub.
# NOTE(review): for Hub datasets `data_dir` selects a sub-directory *inside the dataset
# repository*, not a local download/cache folder (that would be `cache_dir`) — confirm intent.
ds = load_dataset(
    "segments/sidewalk-semantic",
    data_dir="./data",
)

In [None]:
# Only the "train" split is used; it is divided into train/validation subsets below.
dataset = ds["train"]
n_examples = len(dataset)
print(n_examples)

In [None]:
# `dataset.features` is keyed by column name; iterating it yields the two names, used
# below as the image column (`x`) and the label-mask column (`y`).
x, y = list(dataset.features)
# Number of output classes for the segmentation model.
# NOTE(review): hard-coded — confirm it matches the label ids present in the masks.
num_classes = 35

In [None]:
train_split = 0.8
SEED = 42  # fixes the train/valid split and the global torch RNG for reproducible runs

# Seed the global RNG (model init, DataLoader shuffling further down) and give
# random_split its own generator so the split does not depend on earlier RNG use.
torch.manual_seed(SEED)
n_train = int(train_split * len(dataset))
train_ds, valid_ds = torch.utils.data.random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Sizes of the two subsets (train, valid).
print(*map(len, (train_ds, valid_ds)))

In [None]:
# Peek at the first four raw rows of the training subset.
train_ds[:4]

In [None]:
class myDataset(Dataset):
    """Wrap a row-mapping dataset so that it yields ``(image, mask)`` tensors.

    ``dataset[i]`` must return a mapping with an image column and a mask column; for a
    non-integer index (e.g. a slice) each column holds a sequence of values, and the
    results are stacked along a new leading batch dimension.

    * images -> ``TF.to_tensor``: float tensor, (C, H, W).
    * masks  -> int64 tensor of the raw class ids, (C, H, W) (C == 1 for a 2-D mask).

    Args:
        dataset: the wrapped dataset (e.g. a ``torch.utils.data.Subset``).
        image_key: image column name; defaults to the notebook global ``x``.
        label_key: mask column name; defaults to the notebook global ``y``.
    """

    def __init__(self, dataset, image_key=None, label_key=None):
        self.dataset = dataset
        self.image_key = x if image_key is None else image_key
        self.label_key = y if label_key is None else label_key

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _mask_to_tensor(mask):
        """Convert one mask to an int64 (C, H, W) tensor without rescaling its values."""
        # TF.to_tensor divides 8-bit input by 255, so class id 7 became 7/255 and the
        # `.long()` cast in the training loop truncated every id to 0.
        ids = torch.as_tensor(np.asarray(mask), dtype=torch.long)
        # Same channel-first layout to_tensor produced: (H, W) -> (1, H, W), (H, W, C) -> (C, H, W).
        return ids.unsqueeze(0) if ids.ndim == 2 else ids.permute(2, 0, 1)

    def __getitem__(self, i):
        # Fetch the row(s) once: every lookup decodes both columns.
        rows = self.dataset[i]
        images, labels = rows[self.image_key], rows[self.label_key]
        # NumPy integer indices are single items too, not batches.
        if isinstance(i, (int, np.integer)):
            return TF.to_tensor(images), self._mask_to_tensor(labels)
        return (torch.stack([TF.to_tensor(o) for o in images]),
                torch.stack([self._mask_to_tensor(o) for o in labels]))

In [None]:
BATCH_SIZE = 4

# Wrap the subsets so they yield tensors. The isinstance guards keep this cell
# idempotent: re-running it must not wrap an already-wrapped dataset a second time,
# which would make __getitem__ index a (image, label) tuple by column name.
if not isinstance(train_ds, myDataset):
    train_ds = myDataset(train_ds)
if not isinstance(valid_ds, myDataset):
    valid_ds = myDataset(valid_ds)

# Shuffle only the training data; validation order stays fixed.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Let's see what the data looks like

batch = next(iter(train_dl))
images, labels = batch

print(images.shape, labels.shape)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Per-sample tensors are channel-first (C, H, W); matplotlib expects (H, W, C).
ax[0].imshow(images[0].permute(1, 2, 0))
ax[0].set_title("Image")
ax[1].imshow(labels[0].permute(1, 2, 0))
ax[1].set_title("Label mask")
for panel in ax:
    panel.axis("off")
plt.show()

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; H and W are preserved.

    ``mid_channels`` sets the width of the intermediate feature map and falls back to
    ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        def conv_bn_relu(c_in, c_out):
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # One flat Sequential so the submodule indices (and state_dict keys) are 0..5.
        self.double_conv = nn.Sequential(
            *conv_bn_relu(in_channels, hidden),
            *conv_bn_relu(hidden, out_channels),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halving H and W, rounding down), then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample ``x1`` 2x, align it to skip tensor ``x2``, concat, DoubleConv.

    bilinear=True: parameter-free bilinear upsampling; the DoubleConv's middle width is
    ``in_channels // 2``.
    bilinear=False: a learned ConvTranspose2d that also halves the channel count of ``x1``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad (or, for negative differences, crop) x1 so its H/W match x2; any odd
        # remainder goes to the bottom/right edge.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        top, left = dy // 2, dx // 2
        x1 = F.pad(x1, [left, dx - left, top, dy - top])
        # Skip connection first, upsampled features second.
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConv(nn.Module):
    """1x1 convolution projecting feature channels to ``out_channels`` per pixel."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)

In [None]:
class UNet(nn.Module):
    """Three-stage U-Net producing per-pixel class logits.

    Input ``(N, n_channels, H, W)`` -> output ``(N, n_classes, H, W)``. Every ``Up``
    stage pads its upsampled input to the skip connection's size, so odd H/W work;
    H and W must be at least 8 (three 2x max-pools).

    Encoder widths are 64 -> 128 -> 256 -> 512 // factor, with factor = 2 for bilinear
    upsampling (parameter-free, so the bottleneck itself must be half-width for the
    concat with the 256-channel skip to total the 512 channels ``up2`` expects) and
    factor = 1 for transposed convolutions (which halve the width themselves).

    Args:
        n_channels: number of input image channels.
        n_classes: number of output logit channels (one per class).
        bilinear: upsample with ``nn.Upsample`` (True) or ``nn.ConvTranspose2d`` (False).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        # Bottleneck: was Down(256, 512), which fed 512 + 256 = 768 channels into up2
        # in bilinear mode while up2's DoubleConv expects 512.
        self.down3 = Down(256, 512 // factor)
        # Attribute names up2..up4 (no up1/down4) are kept so state_dict keys stay stable.
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        # Each decoder stage consumes the previous stage's output plus the matching skip.
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [None]:
# 3 input channels (RGB images, presumably) and one logit channel per class.
model = UNet(n_channels=3, n_classes=num_classes)

In [None]:
# Smoke-test one batch. Eval mode + no_grad so this check neither updates the
# BatchNorm running statistics nor builds an autograd graph; the previous mode is restored.
batch = next(iter(train_dl))
images, labels = batch
was_training = model.training
model.eval()
with torch.no_grad():
    preds = model(images)
model.train(was_training)
print(preds.shape)
# Logits need one channel per class and the same spatial size as the input images.
assert preds.shape[1] == num_classes
assert preds.shape[-2:] == images.shape[-2:]

In [None]:
# Optimisation setup: Adam with a one-cycle LR schedule and a cross-entropy loss.
lr = 1e-3
epochs = 5
criterion = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=lr, eps=1e-5)
# The scheduler is stepped once per training batch, hence steps_per_epoch=len(train_dl).
sched = optim.lr_scheduler.OneCycleLR(
    opt, max_lr=lr, epochs=epochs, steps_per_epoch=len(train_dl)
)

In [None]:
def fit(model, epochs, opt, sched, criterion, train_dl, valid_dl):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``valid_dl`` after each one.

    Args:
        model: network whose output has class logits along dim 1, e.g. (B, C, H, W).
        epochs: number of passes over ``train_dl``.
        opt: optimizer over ``model``'s parameters.
        sched: LR scheduler, stepped once per training batch.
        criterion: loss called as ``criterion(logits, int64 class-index targets)``.
        train_dl, valid_dl: iterables (with ``len``) of ``(images, labels)`` batches;
            labels may carry a singleton channel dim, (B, 1, H, W), which is squeezed.

    Prints the sample-weighted mean loss and mean per-pixel accuracy for both splits
    every epoch (``nan`` for an empty loader). Leaves the model in eval mode.
    """
    first_param = next(model.parameters(), None)
    # Batches follow the model's device (a no-op for a CPU model).
    device = first_param.device if first_param is not None else torch.device("cpu")

    def _prepare(images, labels):
        # CrossEntropyLoss wants (B, H, W) int64 class indices, not (B, 1, H, W) masks.
        return images.to(device), labels.squeeze(1).long().to(device)

    def _accuracy(preds, labels):
        return (preds.argmax(dim=1) == labels).float().mean().item()

    def _mean(total, count):
        return total / count if count else float("nan")

    for epoch in range(epochs):
        tot_loss, tot_acc, count = 0., 0., 0
        model.train()
        for images, labels in tqdm(train_dl, total=len(train_dl)):
            images, labels = _prepare(images, labels)
            preds = model(images)
            loss = criterion(preds, labels)
            n = len(images)
            count += n
            tot_loss += loss.item() * n
            tot_acc += _accuracy(preds, labels) * n
            loss.backward()
            opt.step()
            opt.zero_grad()
            sched.step()
        print(f"Epoch {epoch} - Training Loss: {_mean(tot_loss, count)} - Training Accuracy: {_mean(tot_acc, count)}")

        model.eval()
        with torch.no_grad():
            tot_loss, tot_acc, count = 0., 0., 0
            for images, labels in tqdm(valid_dl, total=len(valid_dl)):
                # Validation previously skipped this conversion, so the loss call failed
                # and accuracy broadcast (B, H, W) against (B, 1, H, W).
                images, labels = _prepare(images, labels)
                preds = model(images)
                n = len(images)
                count += n
                tot_loss += criterion(preds, labels).item() * n
                tot_acc += _accuracy(preds, labels) * n
        print(f"Epoch {epoch} - Validation Loss: {_mean(tot_loss, count)} - Validation Accuracy: {_mean(tot_acc, count)}")

In [None]:
# Train, reporting training and validation metrics after every epoch.
fit(model, epochs, opt, sched, criterion, train_dl=train_dl, valid_dl=valid_dl)