In [1]:
#!g1.1
import torch
import matplotlib.pyplot as plt

In [3]:
#!g1.1
import torchvision.transforms as T
from torchvision.datasets import OxfordIIITPet

dataset = OxfordIIITPet('Deep_learning\6_object_detection\PETSdataset', target_types='segmentation', download=True)

OSError: [WinError 123] The filename, directory name, or volume label syntax is incorrect: 'Deep_learning\x06_object_detection'

In [3]:
#!g1.1
transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.ToTensor(),
    ]
)

target_transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.PILToTensor(),
        T.Lambda(lambda x: (x - 1).long())
    ]
)

train_dataset = OxfordIIITPet('../datasets/OxfordIIITPet', transform=transform, target_transform=target_transform, target_types='segmentation')
valid_dataset = OxfordIIITPet('../datasets/OxfordIIITPet', transform=transform, split='test', target_transform=target_transform, target_types='segmentation')

In [4]:
#!g1.1
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)



In [5]:
#!g1.1
from tqdm import tqdm


def train(model) -> float:
    model.train()

    train_loss = 0
    total = 0
    correct = 0

    for x, y in tqdm(train_loader, desc='Train'):
        bs = y.size(0)

        x, y = x.to(device), y.squeeze(1).to(device)

        optimizer.zero_grad()

        output = model(x)

        loss = loss_fn(output.reshape(bs, 3, -1), y.reshape(bs, -1))

        train_loss += loss.item()

        loss.backward()

        optimizer.step()

        _, y_pred = output.max(dim=1)
        total += y.size(0) * y.size(1) * y.size(2)
        correct += (y == y_pred).sum().item()

    train_loss /= len(train_loader)
    accuracy = correct / total

    return train_loss, accuracy

In [6]:
#!g1.1
@torch.inference_mode()
def evaluate(model, loader) -> tuple[float, float]:
    model.eval()

    total_loss = 0
    total = 0
    correct = 0

    for x, y in tqdm(loader, desc='Evaluation'):
        bs = y.size(0)

        x, y = x.to(device), y.squeeze(1).to(device)

        output = model(x)

        loss = loss_fn(output.reshape(bs, 3, -1), y.reshape(bs, -1))

        total_loss += loss.item()

        _, y_pred = output.max(dim=1)
        total += y.size(0) * y.size(1) * y.size(2)
        correct += (y == y_pred).sum().item()

    total_loss /= len(loader)
    accuracy = correct / total

    return total_loss, accuracy

In [7]:
#!g1.1
import numpy as np
from PIL import Image

In [8]:
import torch.nn as nn


def conv_plus_conv(in_channels: int, out_channels: int):
    """
    Makes UNet block
    :param in_channels: input channels
    :param out_channels: output channels
    :return: UNet block
    """
    return nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1
        ),
        nn.BatchNorm2d(num_features=out_channels),
        nn.LeakyReLU(0.2),
        nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1
        ),
        nn.BatchNorm2d(num_features=out_channels),
        nn.LeakyReLU(0.2),
    )


class UNET(nn.Module):
    def __init__(self):
        super().__init__()

        base_channels = 16

        self.down1 = conv_plus_conv(3, base_channels)
        self.down2 = conv_plus_conv(base_channels, base_channels * 2)
        self.down3 = conv_plus_conv(base_channels * 2, base_channels * 4)
        self.down4 = conv_plus_conv(base_channels * 4, base_channels * 8)

        self.up1 = conv_plus_conv(base_channels * 2, base_channels)
        self.up2 = conv_plus_conv(base_channels * 4, base_channels)
        self.up3 = conv_plus_conv(base_channels * 8, base_channels * 2)
        self.up4 = conv_plus_conv(base_channels * 16, base_channels * 4)

        self.bottleneck = conv_plus_conv(base_channels * 8, base_channels * 8)

        self.out = nn.Conv2d(in_channels=base_channels, out_channels=3, kernel_size=1)

        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # x.shape = (N, N, 3)

        residual1 = self.down1(x)  # x.shape: (N, N, 3) -> (N, N, base_channels)
        x = self.downsample(residual1)  # x.shape: (N, N, base_channels) -> (N // 2, N // 2, base_channels)

        residual2 = self.down2(x)  # x.shape: (N // 2, N // 2, base_channels) -> (N // 2, N // 2, base_channels * 2)
        x = self.downsample(residual2)  # x.shape: (N // 2, N // 2, base_channels * 2) -> (N // 4, N // 4, base_channels * 2)

        residual3 = self.down3(x)  # x.shape: (N // 4, N // 4, base_channels * 2) -> (N // 4, N // 4, base_channels * 4)
        x = self.downsample(residual3)  # x.shape: (N // 4, N // 4, base_channels * 4) -> (N // 8, N // 8, base_channels * 4)

        residual4 = self.down4(x)  # x.shape: (N // 8, N // 8, base_channels * 4) -> (N // 8, N // 8, base_channels * 8)
        x = self.downsample(residual4)  # x.shape: (N // 8, N // 8, base_channels * 8) -> (N // 16, N // 16, base_channels * 8)
        
        # LATENT SPACE DIMENSION DIM = N // 16
        # SOME MANIPULATION MAYBE
        x = self.bottleneck(x)  # x.shape: (N // 16, N // 16, base_channels * 8) -> (N // 16, N // 16, base_channels * 8)
        # SOME MANIPULATION MAYBE
        # LATENT SPACE DIMENSION DIM = N // 16

        x = nn.functional.interpolate(x, scale_factor=2)  # x.shape: (N // 16, N // 16, base_channels * 8) -> (N // 8, N // 8, base_channels * 8)
        x = torch.cat((x, residual4), dim=1)  # x.shape: (N // 8, N // 8, base_channels * 8) -> ((N // 8, N // 8, base_channels * 16)
        x = self.up4(x)  # x.shape: (N // 8, N // 8, base_channels * 16) -> (N // 8, N // 8, base_channels * 4)
        
        x = nn.functional.interpolate(x, scale_factor=2)  # x.shape: (N // 8, N // 8, base_channels * 4) -> (N // 4, N // 4, base_channels * 4)
        x = torch.cat((x, residual3), dim=1)  # x.shape: (N // 4, N // 4, base_channels * 4) -> ((N // 4, N // 4, base_channels * 8)
        x = self.up3(x)  # x.shape: (N // 4, N // 4, base_channels * 8) -> (N // 4, N // 4, base_channels * 2)
        
        x = nn.functional.interpolate(x, scale_factor=2)  # x.shape: (N // 4, N // 4, base_channels * 2) -> (N // 2, N // 2, base_channels * 2)
        x = torch.cat((x, residual2), dim=1)  # x.shape: (N // 2, N // 2, base_channels * 2) -> (N // 2, N // 2, base_channels * 4)
        x = self.up2(x)  # x.shape: (N // 2, N // 2, base_channels * 4) -> (N // 2, N // 2, base_channels)

        x = nn.functional.interpolate(x, scale_factor=2)  # x.shape: (N // 2, N // 2, base_channels) -> (N, N, base_channels)
        x = torch.cat((x, residual1), dim=1)  # x.shape: (N, N, base_channels) -> (N, N, base_channels * 2)
        x = self.up1(x)  # x.shape: (N, N, base_channels * 2) -> (N, N, base_channels)

        x = self.out(x)  # x.shape: (N, N, base_channels) -> (N, N, 3)

        return x

In [9]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

model = UNET().to(device)

from torch.optim import Adam
optimizer = Adam(model.parameters(), lr=1e-3)
#scheduler = StepLR(optimizer, step_size=25)

loss_fn = nn.CrossEntropyLoss()

cuda:0


In [None]:
train_loss_history, valid_loss_history = [], []
train_accuracy_history, valid_accuracy_history = [], []

num_epochs = 100

best_valid_accuracy = 0

for epoch in range(num_epochs):
    train_loss, train_accuracy = train(model)
    valid_loss, valid_accuracy = evaluate(model, valid_loader)

    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)

    train_accuracy_history.append(train_accuracy)
    valid_accuracy_history.append(valid_accuracy)
    
    best_valid_accuracy = max(valid_accuracy, best_valid_accuracy)
    
    print(f'epoch = {epoch+1} with valid_accuracy = {valid_accuracy*100}')
    print(f'epoch = {epoch+1} with best_valid_accuracy = {best_valid_accuracy*100}')
    
    if valid_accuracy >= 0.885:
        #torch.save(model.state_dict(), 'weights.pt')
        #preds = predict_tta(model, valid_loader_aug, device, iterations=20)  
        #torch.save(preds, 'preds')
        #torch.save(model, 'model')
        break
    
    #if best_valid_accuracy >= 0.9:
    #    preds_best = predict_tta(model, valid_loader_aug, device, iterations=20)  
    #    torch.save(preds_best, 'preds_best')
    #    torch.save(model, 'best_model')
    
    #scheduler.step()

Train: 100%|██████████| 58/58 [00:49<00:00,  1.17it/s]
Evaluation: 100%|██████████| 58/58 [00:32<00:00,  1.76it/s]


epoch = 1 with valid_accuracy = 73.8469262901591
epoch = 1 with best_valid_accuracy = 73.8469262901591


Train: 100%|██████████| 58/58 [00:38<00:00,  1.51it/s]
Evaluation: 100%|██████████| 58/58 [00:30<00:00,  1.93it/s]


epoch = 2 with valid_accuracy = 78.23698013151787
epoch = 2 with best_valid_accuracy = 78.23698013151787


Train: 100%|██████████| 58/58 [00:39<00:00,  1.48it/s]
Evaluation:  17%|█▋        | 10/58 [00:09<00:38,  1.26it/s]

In [17]:
@torch.inference_mode()
def predict(model: nn.Module, loader: DataLoader, device: torch.device):
    model.eval()
    predictions = []
    for x, _ in loader:
        x = x.to(device)
        outputs = model(x)
        y_pred = torch.argmax(outputs, 1)
        predictions.append(y_pred)
    result = torch.cat(predictions)

    return result

In [18]:
np.random.seed(100)
idx = np.random.randint(len(valid_dataset), size=200)

test_dataset = [valid_dataset[i] for i in idx]
test_loader = DataLoader(test_dataset, batch_size=64)

predictions = predict(model, test_loader, device)


In [21]:
predictions.unsqueeze(1).size()

torch.Size([200, 1, 256, 256])

In [None]:
#model.load_state_dict(torch.load('weights_lesson6.pth'))

In [22]:
torch.save(predictions.unsqueeze(1), 'predictions.pth')

In [27]:
predictions_uint8 = predictions.unsqueeze(1).to(torch.uint8)

#print(predictions_uint8)

In [28]:
predictions_uint8.size()

torch.Size([200, 1, 256, 256])

In [29]:
torch.save(predictions_uint8, 'predictions_uint8.pth')