In [None]:
import torch
import torch.nn as nn
from torchsummary import summary

: 

In [None]:
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

: 

In [None]:
# Display the device selected in the previous cell, so the reader can confirm whether the GPU is used.
device

: 

In [None]:
def crop_tensor(input, target):
    """Center-crop ``input`` spatially so its height and width match ``target``.

    Used on the U-Net skip connections: encoder features are larger than the
    upsampled decoder features because the convolutions are unpadded.

    Args:
        input: 4-D tensor indexed as (N, C, H, W); only dims 2 and 3 are cropped.
        target: 4-D tensor whose dims 2 and 3 give the desired output size.

    Returns:
        A view of ``input`` with shape (N, C, target.H, target.W).

    Raises:
        ValueError: if ``target`` is larger than ``input`` in either spatial dim.
    """
    in_h, in_w = input.size(2), input.size(3)
    out_h, out_w = target.size(2), target.size(3)
    if out_h > in_h or out_w > in_w:
        raise ValueError(
            "cannot crop {}x{} input to larger {}x{} target".format(in_h, in_w, out_h, out_w)
        )
    # Crop height and width independently (non-square maps).
    # Slice by the exact target length so an odd difference still yields the
    # right size; the extra row/column is dropped from the bottom/right.
    top = (in_h - out_h) // 2
    left = (in_w - out_w) // 2
    return input[:, :, top:top + out_h, left:left + out_w]

: 

In [None]:
def conv_block(in_channels, out_channels):
    """Build two unpadded 3x3 convolutions, each followed by a ReLU.

    The first convolution maps ``in_channels`` -> ``out_channels``; the second
    keeps ``out_channels``. With no padding, each convolution shrinks H and W
    by 2, so the block shrinks them by 4 in total.
    """
    layers = []
    for conv_in in (in_channels, out_channels):
        layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

: 

In [None]:
class UNet(nn.Module):
    """U-Net with unpadded 3x3 convolutions and cropped skip connections.

    The encoder has four stages (64/128/256/512 channels), each followed by 2x2
    max-pooling, then a 1024-channel bottleneck. The decoder has four stages.
    Each decoder stage:
      - upsamples with a stride-2 transposed convolution;
      - center-crops the matching encoder feature map to the upsampled size;
      - concatenates [upsampled, skip] along the channel dim;
      - applies a conv block.
    A final 1x1 convolution produces ``num_classes`` output channels.

    Because the convolutions are unpadded, the output is spatially smaller than
    the input.

    Args:
        IN_CHANNELS: number of channels in the input image (default 3).
        num_classes: number of output channels of the final 1x1 convolution
            (default 7, the value previously hard-coded).
    """

    def __init__(self, IN_CHANNELS=3, num_classes=7):
        super().__init__()
        # Encoder (contracting path).
        self.encode_conv1 = conv_block(IN_CHANNELS, 64)
        self.encode_conv2 = conv_block(64, 128)
        self.encode_conv3 = conv_block(128, 256)
        self.encode_conv4 = conv_block(256, 512)
        self.encode_conv5 = conv_block(512, 1024)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder (expanding path). Each decode_conv takes 2x channels because
        # the upsampled tensor is concatenated with an equally wide skip tensor.
        self.conv_transpose1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.decode_conv1 = conv_block(1024, 512)
        self.conv_transpose2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.decode_conv2 = conv_block(512, 256)
        self.conv_transpose3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.decode_conv3 = conv_block(256, 128)
        self.conv_transpose4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.decode_conv4 = conv_block(128, 64)
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` (assumed (N, C, H, W)) and return per-pixel logits."""
        # Encoder: keep each stage's pre-pooling output for the skip connections.
        enc1 = self.encode_conv1(x)
        enc2 = self.encode_conv2(self.maxpool(enc1))
        enc3 = self.encode_conv3(self.maxpool(enc2))
        enc4 = self.encode_conv4(self.maxpool(enc3))
        bottleneck = self.encode_conv5(self.maxpool(enc4))

        # Decoder: upsample, crop the matching skip to the same size, concat, convolve.
        up = self.conv_transpose1(bottleneck)
        dec = self.decode_conv1(torch.cat([up, crop_tensor(enc4, up)], dim=1))
        up = self.conv_transpose2(dec)
        dec = self.decode_conv2(torch.cat([up, crop_tensor(enc3, up)], dim=1))
        up = self.conv_transpose3(dec)
        dec = self.decode_conv3(torch.cat([up, crop_tensor(enc2, up)], dim=1))
        up = self.conv_transpose4(dec)
        dec = self.decode_conv4(torch.cat([up, crop_tensor(enc1, up)], dim=1))
        return self.out(dec)

: 

In [None]:
# Instantiate the U-Net and move its parameters to the selected device
# (nn.Module.to moves the module in place and returns it).
model = UNet()
model.to(device)

: 

In [None]:
def train_one_epoch(epoch_index, tb_writer, report_every=1000, *,
                    loader=None, net=None, optim=None, criterion=None):
    """Train for one pass over the training data and return a recent average loss.

    Args:
        epoch_index: zero-based epoch number, used to compute the global step
            logged to TensorBoard.
        tb_writer: object with ``add_scalar(tag, value, step)``
            (e.g. a SummaryWriter), or None to skip TensorBoard logging.
        report_every: number of batches per reporting window (default 1000).
        loader: iterable of (inputs, labels) batches that supports ``len``.
            Defaults to the notebook-level ``training_loader``.
        net: model to train. Defaults to the notebook-level ``model``.
        optim: optimizer. Defaults to the notebook-level ``optimizer``.
        criterion: loss function ``criterion(outputs, labels)``. Defaults to
            the notebook-level ``loss_fn``.

    Returns:
        The mean per-batch loss of the last completed reporting window. If the
        epoch is shorter than one window, the mean over all batches instead
        (previously this returned 0). Returns 0.0 for an empty loader.
    """
    # Fall back to the notebook globals so existing calls keep working.
    loader = training_loader if loader is None else loader
    net = model if net is None else net
    optim = optimizer if optim is None else optim
    criterion = loss_fn if criterion is None else criterion

    # Batches must be on the same device as the model's weights, otherwise the
    # forward pass fails when the model lives on the GPU.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    # Ensure training-mode behaviour even if a previous cell switched to eval().
    net.train()

    running_loss = 0.0
    batches_in_window = 0
    last_loss = None

    for i, (inputs, labels) in enumerate(loader):
        if target_device is not None:
            if torch.is_tensor(inputs):
                inputs = inputs.to(target_device)
            if torch.is_tensor(labels):
                labels = labels.to(target_device)

        # Zero gradients for every batch so they do not accumulate across steps.
        optim.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optim.step()

        running_loss += loss.item()
        batches_in_window += 1
        if batches_in_window == report_every:
            last_loss = running_loss / report_every  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            if tb_writer is not None:
                tb_x = epoch_index * len(loader) + i + 1
                tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.0
            batches_in_window = 0

    if last_loss is None:
        # No full reporting window completed: average over every batch seen.
        last_loss = running_loss / batches_in_window if batches_in_window else 0.0
    return last_loss



: 

In [None]:
# image = torch.rand((1,3,300,300)).to(device)
# print(model(image))

: 

In [None]:
# input_size excludes the batch dimension: (channels, height, width); batch size is passed separately via batch_size
# summary(model, input_size = (3,300, 300),batch_size=200)

: 