In [8]:
import torch 
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

In [6]:
class UNet(nn.Module):
    """U-Net style encoder-decoder producing a per-pixel output map.

    Encoder: four stages of two 3x3 'same' convolutions (each followed by ReLU)
    and a 2x2 max-pool, then a two-convolution bottleneck (256 channels).
    Decoder: four stages that upsample with a stride-2 transposed convolution,
    concatenate the matching encoder feature map along the channel axis (skip
    connection), and apply two 3x3 convolutions. A final 1x1 convolution maps
    to ``out_channels``.

    Args:
        in_channels: Channels of the input images (default 3).
        out_channels: Channels of the output map (default 1). The output is
            raw logits; no sigmoid/softmax is applied.

    Input is expected as a batched ``(N, in_channels, H, W)`` tensor with
    ``H, W >= 16`` (four 2x pooling steps must leave at least one pixel).
    Odd sizes are supported; the output has the same ``H, W`` as the input.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder: widths 16 -> 32 -> 64 -> 128, bottleneck 256.
        self.conv1 = self._double_conv(in_channels, 8, 16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = self._double_conv(16, 24, 32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = self._double_conv(32, 48, 64)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = self._double_conv(64, 96, 128)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = self._double_conv(128, 192, 256)

        # Decoder: kernel_size=2, stride=2 doubles H and W, undoing one pool.
        # Each double-conv input = upsampled channels + skip channels, and its
        # mid width mirrors the corresponding encoder stage.
        self.transpose_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv6 = self._double_conv(128 + 128, 192, 128)

        self.transpose_conv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv7 = self._double_conv(64 + 64, 96, 64)

        self.transpose_conv3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv8 = self._double_conv(32 + 32, 48, 32)

        self.transpose_conv4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.conv9 = self._double_conv(16 + 16, 24, 16)

        # 1x1 projection to per-pixel logits (no activation).
        self.conv10 = nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, images):
        """Return ``(N, out_channels, H, W)`` logits for ``(N, C, H, W)`` images."""
        # Encoder: keep each stage's pre-pool output for the skip connections.
        c1 = self.conv1(images)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))

        # Decoder
        c6 = self.conv6(self._up_and_concat(self.transpose_conv1, c5, c4))
        c7 = self.conv7(self._up_and_concat(self.transpose_conv2, c6, c3))
        c8 = self.conv8(self._up_and_concat(self.transpose_conv3, c7, c2))
        c9 = self.conv9(self._up_and_concat(self.transpose_conv4, c8, c1))

        return self.conv10(c9)

    @staticmethod
    def _double_conv(in_channels, mid_channels, out_channels):
        """Two 'same'-padded 3x3 convolutions, each followed by a ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(),
        )

    @staticmethod
    def _up_and_concat(upsample, x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concatenate on channels."""
        # Pooling floors odd sizes (25 -> 12), so a stride-2 transpose conv is
        # ambiguous (24 or 25); output_size picks the size matching the skip.
        up = upsample(x, output_size=skip.shape[-2:])
        return torch.cat((up, skip), dim=1)



In [7]:
# Instantiate the model.
# NOTE(review): this rebinds the name `UNet` from the class to an instance. Re-running
# this cell on its own then calls the instance with no arguments, which raises a
# TypeError (forward() requires `images`).
# TODO: rename to e.g. `model = UNet()` and update `UNet.parameters()` in the
# training cell in the same change.
UNet = UNet()

#### Training

Dataset: [TGS Salt Identification Challenge (Kaggle)](https://www.kaggle.com/c/tgs-salt-identification-challenge/data)

In [None]:
# Training hyperparameters.
N_EPOCHS = 20
LR = 0.005
BATCH_SIZE = 32

# `UNet` here is the model instance created in the previous cell (that cell
# rebinds the class name to the instance).
optimizer = Adam(UNet.parameters(), lr=LR)

# The model's final layer emits a single logit channel per pixel (binary mask), so
# use a sigmoid-based loss. CrossEntropyLoss applies log-softmax over the channel
# axis; with one channel that is identically 0, giving a constant loss and no
# gradient. BCEWithLogitsLoss expects float targets shaped like the logits
# (N, 1, H, W) with values in [0, 1].
criterion = nn.BCEWithLogitsLoss()