In [1]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Resize every image to a fixed 224x224 so samples share one shape and can be
# batched, then convert each PIL image to a float tensor.
aircraft_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Download (if not already present under ./cnn_olav) and load FGVC-Aircraft
# with torchvision's default split/annotation settings.
train_data = datasets.FGVCAircraft(
    root='cnn_olav',
    download=True,
    transform=aircraft_transform,
)

In [3]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. ``padding=1`` with a 3x3 kernel preserves H and W, so only
    the channel dimension changes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # One conv/bn/relu triple per stage; the first stage reads in_channels,
        # the second reads the previous stage's out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages; output has ``out_channels`` and the input's H, W."""
        return self.conv(x)


class UNet(nn.Module):
    """Small U-Net-style encoder/decoder that returns per-pixel class logits.

    Encoder: four ``DoubleConv`` stages separated by three 2x2 max-pools, so the
    bottleneck sits at 1/8 of the input resolution. Decoder: three 2x2, stride-2
    transposed convolutions bring the features back to full resolution, each
    followed by concatenation with the same-resolution encoder map (skip
    connection). A full-resolution ``DoubleConv`` then fuses the last skip, and
    a 1x1 convolution produces the logits.

    Args:
        in_channels: Channels of the input image (default 3).
        num_classes: Number of output logit channels (default 100).

    Input is expected as (N, in_channels, H, W) with H and W divisible by 8;
    otherwise the upsampled maps and the skip connections differ in size and
    ``torch.cat`` raises. Output is (N, num_classes, H, W).
    """

    def __init__(self, in_channels=3, num_classes=100):
        super(UNet, self).__init__()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder.
        self.down1 = DoubleConv(in_channels, 64)
        self.down2 = DoubleConv(64, 128)
        self.down3 = DoubleConv(128, 256)
        self.down4 = DoubleConv(256, 512)
        # Decoder: each transposed conv doubles H and W. Input channel counts
        # include the encoder skip concatenated in forward().
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(256 + 256, 128, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(128 + 128, 64, kernel_size=2, stride=2)
        # Only three max-pools are applied, so after up3 the features are already
        # back at the input resolution. A fourth stride-2 transposed conv here
        # would produce a 2H x 2W output; instead fuse the last skip connection
        # at full resolution (the name up4 is kept for the decoder's last stage).
        self.up4 = DoubleConv(64 + 64, 32)
        self.final = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """Map (N, in_channels, H, W) to per-pixel logits (N, num_classes, H, W)."""
        # Encoder: each max-pool halves H and W.
        x1 = self.down1(x)                          # (N, 64,  H,   W)
        x2 = self.down2(self.maxpool(x1))           # (N, 128, H/2, W/2)
        x3 = self.down3(self.maxpool(x2))           # (N, 256, H/4, W/4)
        x4 = self.down4(self.maxpool(x3))           # (N, 512, H/8, W/8)
        # Decoder: upsample, then concatenate the same-resolution encoder map.
        x = torch.cat([self.up1(x4), x3], dim=1)    # (N, 256 + 256, H/4, W/4)
        x = torch.cat([self.up2(x), x2], dim=1)     # (N, 128 + 128, H/2, W/2)
        x = torch.cat([self.up3(x), x1], dim=1)     # (N, 64 + 64,   H,   W)
        x = self.up4(x)                             # (N, 32,        H,   W)
        return self.final(x)                        # (N, num_classes, H, W)



In [4]:
# Seed torch's global RNG so the weight initialisation below (and any later
# torch-driven randomness) is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# SGD hyperparameters.
LEARNING_RATE = 0.001
MOMENTUM = 0.9

unet = UNet()
# NOTE(review): nn.CrossEntropyLoss accepts either (N, C) logits with (N,) class
# targets, or (N, C, H, W) logits with (N, H, W) per-pixel targets. UNet returns
# per-pixel logits (its last layer is a 1x1 conv with no pooling), while
# FGVCAircraft targets are presumably one class index per image. The output
# likely needs spatial reduction (e.g. a mean over H, W) before the loss.
# TODO confirm in the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(unet.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)