![](./diagram.JPG)

In [1]:
from custom_functions import *

cuda:0


In [2]:
data_loader, data = load_med_data(batch_size = 1)

In [3]:
class Conv_layer(nn.Module):
    def __init__(self, in_channels = 64, out_channels = 64, kernel_size = 3):
        super(Conv_layer, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, padding = "same", bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = kernel_size, padding = "same",bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        
        return x

class Unet(nn.Module):
        def __init__(self, in_dims = [(3,64),(64,128),(128,256),(256,512)], out_dims = [(1024,512),(512,256),(256,128),(128,64)], out_channels=3):
            super(Unet, self).__init__()
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.sig = nn.Sigmoid()
            
            self.down_scale = nn.ModuleList()
            self.up_scale_layer = nn.ModuleList()
            self.up_scale_conv = nn.ModuleList()
            
            for i in in_dims:
                self.down_scale.append(Conv_layer(in_channels=i[0], out_channels =i[1]))
            
            for i in out_dims:
                self.up_scale_conv.append(nn.ConvTranspose2d(in_channels = i[0], out_channels = i[1], kernel_size =2, stride=2))
                self.up_scale_layer.append(Conv_layer(in_channels=i[0], out_channels =i[1]))
            
            self.bottom_layer = Conv_layer(in_dims[-1][1] , out_dims[0][0])                                     
            self.final_conv = nn.Conv2d(in_channels = out_dims[-1][1], out_channels = out_channels, kernel_size = 1, padding = "same")
            
        def forward(self, x):
            
            cat_tensors = []
                                     
            for fxn in self.down_scale:
                x = fxn(x)
                cat_tensors.append(x)
                x = self.pool(x)
            
            x = self.bottom_layer(x)
            
            for index, fxn in enumerate(self.up_scale_conv):
                x = self.up_scale_conv[index](x)
                x = self.up_scale_layer[index](torch.cat((cat_tensors[-1*(index+1)],x),1))
            
            x = self.final_conv(x)
            return self.sig(x)
            
    

In [4]:
from torch.optim import Adam
import os
model = Unet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.01)

In [5]:
num_epochs = 100

In [6]:
def train_epoch(epoch_index):

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(data_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

    return epoch_index

In [7]:
for i in range(num_epochs):
    train_epoch(i)
    os.system("cls")
    print(i/num_epochs)

0.0
0.01
0.02
0.03
0.04
0.05
0.06
0.07
0.08
0.09
0.1
0.11
0.12
0.13
0.14
0.15
0.16
0.17
0.18
0.19
0.2
0.21
0.22
0.23
0.24
0.25
0.26
0.27
0.28
0.29
0.3
0.31
0.32
0.33
0.34
0.35
0.36
0.37
0.38
0.39
0.4
0.41
0.42
0.43
0.44
0.45
0.46
0.47
0.48
0.49
0.5
0.51
0.52
0.53
0.54
0.55
0.56
0.57
0.58
0.59
0.6
0.61
0.62
0.63
0.64
0.65
0.66
0.67
0.68
0.69
0.7
0.71
0.72
0.73
0.74
0.75
0.76
0.77
0.78
0.79
0.8
0.81
0.82
0.83
0.84
0.85
0.86
0.87
0.88
0.89
0.9
0.91
0.92
0.93
0.94
0.95
0.96
0.97
0.98
0.99


In [None]:
plt.imshow( transforms.functional.to_pil_image(data[1][0]))
plt.show()
plt.imshow( transforms.functional.to_pil_image(data[1][1]))
plt.show()
plt.imshow( transforms.functional.to_pil_image(model(data[1][0].reshape(1,3,512,512)).reshape((3,512,512))))
plt.show()