In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np 
import matplotlib.pyplot as plt
import torch.optim as optim 
import torch.utils.data as Data
import torchvision

In [7]:
class ConvNextBlock(nn.Module):
    """Residual block loosely following the ConvNeXt design.

    Pipeline: depthwise conv -> BatchNorm -> 1x1 conv (expand to
    ``4 * out_channels``) -> ReLU -> 1x1 conv (project to ``out_channels``),
    followed by a residual addition.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution (and of the projection
            shortcut, when one is needed).
        padding: Padding of the depthwise convolution.

    Note:
        The residual addition requires the main branch and the shortcut to have
        the same spatial size. This holds for the defaults and, more generally,
        whenever ``kernel_size == 2 * padding + 1``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvNextBlock, self).__init__()
        # groups=in_channels makes this a depthwise convolution.
        self.dwconv = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels)
        self.pwconv1 = nn.Conv2d(in_channels, 4 * out_channels, kernel_size=1)
        self.pwconv2 = nn.Conv2d(4 * out_channels, out_channels, kernel_size=1)
        self.norm = nn.BatchNorm2d(in_channels)
        self.act = nn.ReLU()
        # The identity shortcut only works when the block keeps the tensor
        # shape; otherwise project the input with a strided 1x1 convolution so
        # it can be added to the main branch.
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        residual = self.shortcut(x)
        out = self.dwconv(x)
        out = self.norm(out)
        out = self.pwconv1(out)
        out = self.act(out)
        out = self.pwconv2(out)
        # Out-of-place add keeps autograd bookkeeping simple.
        out = out + residual
        return out

class ConvNeXt(nn.Module):
    """Small image classifier built from two ``ConvNextBlock`` stages.

    Layout: 3x3 conv (3 -> 32) + ReLU, block (32 -> 64), 2x2 max-pool,
    block (64 -> 128), 2x2 max-pool, then a linear classifier.

    The linear layer is sized for a ``128 x 8 x 8`` feature map, i.e. the
    network expects 3-channel ``32 x 32`` inputs (two 2x poolings: 32 -> 8).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super(ConvNeXt, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.block1 = ConvNextBlock(32, 64)
        self.block2 = ConvNextBlock(64, 128)
        # Stateless, so the same pooling module is reused after both blocks.
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(128 * 8 * 8, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a batch ``x``."""
        x = F.relu(self.conv1(x))
        x = self.block1(x)
        x = self.pool(x)
        x = self.block2(x)
        x = self.pool(x)
        # Flatten everything except the batch axis. Unlike ``view(-1, ...)``,
        # this never folds a spatial-size mismatch into the batch dimension, so
        # a wrongly sized input fails loudly at ``self.fc``.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Instantiate the ConvNeXt classifier defined above with 10 output classes.
model = ConvNeXt(num_classes=10)

        

In [None]:
class EncoderBlock(nn.Module):
    """Contracting-path stage: (3x3 Conv -> ReLU) repeated, then a max-pool.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every convolution in the block.
        layers: Number of Conv -> ReLU pairs (at least one pair is always built).
        sampling_factor: Kernel size (and therefore stride) of the max-pool.
        padding: Padding passed to every ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, layers=2, sampling_factor=2, padding="same"):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.encoder.append(nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding))
        self.encoder.append(nn.ReLU())
        for _ in range(layers - 1):
            self.encoder.append(nn.Conv2d(out_channels, out_channels, 3, 1, padding=padding))
            self.encoder.append(nn.ReLU())
        # The pool is kept outside ``self.encoder`` so that ``forward`` can
        # return the full-resolution features for the skip connection as well
        # as the down-sampled tensor.
        self.mp = nn.MaxPool2d(sampling_factor)

    def forward(self, x):
        """Run the block.

        Returns:
            Tuple ``(pooled, features)``: ``features`` is the output of the
            Conv -> ReLU stack and ``pooled`` is its max-pooled version.
        """
        for layer in self.encoder:
            x = layer(x)
        mp_out = self.mp(x)
        return mp_out, x
    
class DecoderBlock(nn.Module):
    """Expanding-path stage: transposed conv up-sampling, optional skip
    concatenation, then (3x3 Conv -> ReLU) repeated ``layers`` times.

    Args:
        in_chans: Channels of the tensor coming from the previous stage.
        out_chans: Channels produced by the block.
        layers: Number of Conv -> ReLU pairs (at least one pair is always built).
        skip_connection: If True, ``forward`` concatenates encoder features
            with the up-sampled tensor before the convolutions.
        sampling_factor: Kernel size and stride of the transposed convolution.
        padding: Padding passed to every ``nn.Conv2d``. Any value other than
            ``"same"`` makes ``forward`` centre-crop the encoder features to
            the spatial size of the up-sampled tensor.

    Note:
        With ``skip_connection=True`` the encoder features are assumed to have
        ``in_chans // 2`` channels, matching the up-sampled tensor.
    """

    def __init__(self, in_chans, out_chans, layers=2, skip_connection=True, sampling_factor=2, padding="same"):
        super().__init__()
        up_chans = in_chans // 2
        self.decoder = nn.ModuleList()
        self.tconv = nn.ConvTranspose2d(in_chans, up_chans, sampling_factor, sampling_factor)

        # After concatenation the tensor carries the up-sampled channels plus
        # the same number of encoder channels; without a skip only the former.
        conv_in = 2 * up_chans if skip_connection else up_chans
        self.decoder.append(nn.Conv2d(conv_in, out_chans, 3, 1, padding=padding))
        self.decoder.append(nn.ReLU())
        for _ in range(layers - 1):
            self.decoder.append(nn.Conv2d(out_chans, out_chans, 3, 1, padding=padding))
            self.decoder.append(nn.ReLU())

        self.skip_connection = skip_connection
        self.padding = padding

    def forward(self, x, enc_features=None):
        """Up-sample ``x``, merge ``enc_features`` if skips are on, and convolve.

        Raises:
            ValueError: If the block uses skip connections and
                ``enc_features`` is None.
        """
        x = self.tconv(x)
        if self.skip_connection:
            if enc_features is None:
                raise ValueError("enc_features is required when skip_connection=True")
            if self.padding != "same":
                # Centre-crop the encoder features to the size of x, handling
                # height and width independently so non-square inputs work.
                h, w = x.size(-2), x.size(-1)
                top = (enc_features.size(-2) - h) // 2
                left = (enc_features.size(-1) - w) // 2
                enc_features = enc_features[:, :, top:top + h, left:left + w]
            x = torch.cat((enc_features, x), dim=1)
        for dec in self.decoder:
            x = dec(x)
        return x
    
class UNet(nn.Module):
    """U-shaped encoder/decoder network producing per-pixel logits.

    ``depth`` encoder blocks are stacked with widths starting at 64 and
    doubling each time, followed by ``depth - 1`` decoder blocks that halve the
    width again, and a final 1x1 convolution mapping to ``nclass`` channels.

    Args:
        nclass: Number of output channels of the final 1x1 convolution.
        in_chans: Number of channels of the input image.
        depth: Number of encoder blocks.
        layers, sampling_factor, padding: Forwarded to every encoder and
            decoder block.
        skip_connection: Forwarded to every decoder block.
    """

    def __init__(self, nclass=1, in_chans=1, depth=5, layers=2, sampling_factor=2, skip_connection=True, padding="same"):
        super().__init__()

        # Contracting path: 64, 128, 256, ... output channels.
        down_blocks = []
        width = 64
        for _ in range(depth):
            down_blocks.append(EncoderBlock(in_chans, width, layers, sampling_factor, padding))
            in_chans = width
            width = width * 2
        self.encoder = nn.ModuleList(down_blocks)

        # Expanding path: each block halves the channel count.
        up_blocks = []
        width = in_chans // 2
        for _ in range(depth - 1):
            up_blocks.append(DecoderBlock(in_chans, width, layers, skip_connection, sampling_factor, padding))
            in_chans = width
            width = width // 2
        self.decoder = nn.ModuleList(up_blocks)

        # 1x1 convolution that turns the last feature map into class logits.
        self.logits = nn.Conv2d(in_chans, nclass, 1, 1)

    def forward(self, x):
        """Return the logits for input ``x``.

        Each encoder block is expected to return a pair
        ``(downsampled, features)``. The ``features`` of the deepest block
        seed the decoder (its down-sampled output is discarded); the remaining
        ones are handed to the decoder blocks deepest-first.
        """
        skips = []
        for down in self.encoder:
            x, features = down(x)
            skips.append(features)

        x = skips.pop()
        for up, skip in zip(self.decoder, reversed(skips)):
            x = up(x, skip)

        return self.logits(x)