In [30]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
 
class UNet(nn.Module):
    """2D U-Net with unpadded ("valid") convolutions and centre-cropped skip connections.

    Layout: three encoder stages (64 / 128 / 256 channels, each followed by 2x2
    max-pooling), a 512-channel bottleneck, two decoder stages, and a final block.
    The bottleneck and each decoder stage end in a stride-2 transposed convolution
    that doubles H and W. The output of each of those is concatenated with the
    (cropped) encoder feature map of the same level.

    Because the convolutions are unpadded, the output is spatially smaller than the
    input (a 572x572 input gives a 484x484 output).

    Args:
        in_channel: number of channels of the input image.
        out_channel: number of output channels (per-pixel class scores).
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Encoder
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck: identical layer layout to an expansive block (256 -> 512 -> upsampled 256),
        # so the Sequential indices / state_dict keys are the same as a hand-written copy.
        self.bottleneck = self.expansive_block(256, 512, 256)
        # Decoder: each stage's input is [upsampled, cropped skip] concatenated on channels,
        # hence in_channels = 2 * the upsampled channel count.
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    @staticmethod
    def _double_conv_layers(in_channels, out_channels, kernel_size):
        """Return the layers of two unpadded Conv2d -> ReLU -> BatchNorm2d stages (as a list)."""
        return [
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """Encoder stage: two unpadded conv/ReLU/BatchNorm stages.

        Each unpadded conv shrinks H and W by ``kernel_size - 1`` (2 for the default).
        """
        return nn.Sequential(*self._double_conv_layers(in_channels, out_channels, kernel_size))

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Decoder stage: two conv/ReLU/BatchNorm stages followed by a x2 upsampling.

        ``kernel_size`` applies to the two convolutions only; the transposed
        convolution always uses kernel 3, stride 2, padding 1, output_padding 1.
        """
        return nn.Sequential(
            *self._double_conv_layers(in_channels, mid_channel, kernel_size),
            # (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 2H: exactly doubles the spatial size.
            nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels,
                               kernel_size=3, stride=2, padding=1, output_padding=1),
        )

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Output stage: two conv/ReLU/BatchNorm stages, then a projection to ``out_channels``.

        The projection uses padding=1, which preserves H and W only when kernel_size == 3.
        """
        return nn.Sequential(
            *self._double_conv_layers(in_channels, mid_channel, kernel_size),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            # NOTE(review): ReLU + BatchNorm on the final per-pixel scores is unusual when they are
            # fed to CrossEntropyLoss (which expects raw logits); kept so existing checkpoints load.
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate ``upsampled`` and the skip tensor ``bypass`` along dim 1 (channels).

        When ``crop`` is true, ``bypass`` is centre-cropped (or zero-padded, if it
        is smaller) to the spatial size of ``upsampled``. Height and width are
        handled independently. An odd size difference drops the extra row/column
        from the bottom/right, so odd-sized or non-square inputs still produce
        matching shapes for ``torch.cat``.
        """
        if crop:
            diff_h = bypass.size(2) - upsampled.size(2)
            diff_w = bypass.size(3) - upsampled.size(3)
            top, left = diff_h // 2, diff_w // 2
            bottom, right = diff_h - top, diff_w - left
            # Negative padding crops; F.pad orders the last two dims as (left, right, top, bottom).
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the network on a batch ``x`` of shape (N, in_channel, H, W).

        Returns a tensor of shape (N, out_channel, H_out, W_out), where H_out < H and
        W_out < W because of the unpadded convolutions.
        """
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode: crop each encoder map to the upsampled size before concatenating
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        return self.final_layer(decode_block1)

In [31]:
# Train the U-Net on a single in-memory batch for a fixed number of epochs.
# NOTE(review): `inputs` and `labels` are not defined anywhere in this notebook (this cell
# previously raised NameError); a data-loading cell must define them first. Presumably
# `inputs` is a float tensor (N, 1, H, W) and `labels` a long tensor with one class index per
# *output* pixel (e.g. (N, 484, 484) for 572x572 inputs) — TODO confirm.
SEED = 0
NUM_CLASSES = 2  # number of segmentation classes == UNet out_channel
num_epochs = 50

torch.manual_seed(SEED)  # reproducible weight initialisation
unet = UNet(in_channel=1, out_channel=NUM_CLASSES)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr=0.01, momentum=0.99)

unet.train()
for epoch in range(num_epochs):  # repeatedly fits the same batch
    optimizer.zero_grad()
    outputs = unet(inputs)
    # Move the class axis last and flatten to one row of scores per pixel:
    # (N, C, H, W) -> (N*H*W, C). reshape replaces the deprecated Tensor.resize and takes
    # H, W from the output itself instead of separate width/height globals.
    outputs = outputs.permute(0, 2, 3, 1).reshape(-1, NUM_CLASSES)
    # Flatten into a new name so `labels` itself is never overwritten (cell stays re-runnable).
    labels_flat = labels.reshape(-1)
    loss = criterion(outputs, labels_flat)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch + 1}/{num_epochs}: loss {loss.item():.5f}")

NameError: name 'inputs' is not defined

In [None]:
# for epoch in range(num_epochs):  # loop over the dataset multiple times
    
#     # for train:
#     running_loss = 0.0
#     # for validation:
#     dev_loss = 0
#     predict_value = torch.zeros(1) # validation predictions:
#     true_value = torch.zeros(1) # true labels of the validation part:
    
#     for index, (inputs, labels, patient_name) in enumerate(tqdm(train_data_loader)):
        
#         names_box_train.extend(patient_name)
#         labels_box_train.extend(labels.cpu().numpy().tolist())
        
#         if index<=110:# train part: 
            
#             model.train()
            
#             # zero the parameter gradients
#             optimizer.zero_grad()
        
# #             inputs, labels = inputs.cuda(), labels.cuda()
#             inputs, labels = inputs.to(device), labels.to(device)
        
#             # forward + backward + optimize
#             inputs = inputs.unsqueeze(dim=1).float()
        
#             inputs = F.interpolate(inputs, size=[sample_size,sample_size,sample_size],mode='trilinear',align_corners=False)
        
#             outputs = model(inputs)
        
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
        
#             # print statistics
#             running_loss += loss.item()
            
#         else: # validation part:
            
#             model.eval()
            
# #             inputs, labels = inputs.cuda(), labels.cuda()
#             inputs, labels = inputs.to(device), labels.to(device)
        
#             # forward pass only (no backward/optimize during validation)
#             inputs = inputs.unsqueeze(dim=1).float()
        
#             inputs = F.interpolate(inputs, size=[sample_size,sample_size,sample_size],mode='trilinear',align_corners=False)
        
#             outputs = model(inputs)
        
#             test_value = F.softmax(outputs)
#             test_value_1 = test_value[:,1]
#             test_value_1 = test_value_1.cpu()
#             predict_value = torch.cat((predict_value,test_value_1))
    
#             test_value_2 = labels
#             test_value_2 = test_value_2.cpu().float()
#             true_value = torch.cat((true_value, test_value_2))
            
#     print('[%d] loss: %.5f' %(epoch + 1, running_loss / 384))
# #     writer.add_scalar('Loss', running_loss, epoch)
#     logger.log_value('loss',running_loss,epoch)
#     running_loss = 0.0