In [97]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv1 = nn.Conv3d(in_c, out_c, kernel_size=3, padding="same")
        self.bn1 = nn.BatchNorm3d(out_c)

        self.conv2 = nn.Conv3d(out_c, out_c, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm3d(out_c)

        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool3d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)

        return x, p

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.ConvTranspose3d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c+out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = torch.cat([x, skip], axis=1)
        x = self.conv(x)
        return x

class build_unet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        self.e1 = encoder_block(1, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        """ Bottleneck """
        self.b = conv_block(512, 1024)

        """ Decoder """
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        """ Classifier """
        self.outputs = nn.Conv3d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        
        print(inputs.shape)
        
        """ Encoder """
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        """ Bottleneck """
        b = self.b(p4)

        """ Decoder """
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        outputs = self.outputs(d4)
        
        print(outputs.shape)

        return outputs



In [62]:
import os
import time
from glob import glob

import torch
from torch.utils.data import DataLoader
import torch.nn as nn


from loss import DiceLoss, DiceBCELoss
from utils import seeding, create_dir, epoch_time

def train(model, loader, optimizer, loss_fn, device):
    epoch_loss = 0.0

    model.train()
    for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss = epoch_loss/len(loader)
    return epoch_loss


def evaluate(model, loader, loss_fn, device):
    epoch_loss = 0.0

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            epoch_loss += loss.item()

        epoch_loss = epoch_loss/len(loader)
    return epoch_loss


In [63]:
from read_write_mrc import read_mrc
import numpy as np


img_directory = 'slices/images'
mask_directory = 'slices/masks'



img_name_list = os.listdir(img_directory)
mask_name_list =  os.listdir(mask_directory)


img_name_list.sort()
mask_name_list.sort()

# num_images = len(img_name)

X_list = []
Y_list = []

zeros = 0
nonzeros= 0

for img_name,mask_name in zip(img_name_list, mask_name_list):
    # assert str(img_name[0:-4]+str('_labels.mrc')) == str(mask_name), "Mask doesn't match to image"
    
    assert len(img_name_list) == len(mask_name_list), "Number of images != number of masks"
    
    
    img = read_mrc(os.path.join(img_directory, img_name))
    mask = read_mrc(os.path.join(mask_directory, mask_name))
    
    mask = np.where((mask == 5), np.ones(mask.shape), np.zeros(mask.shape))
    
    assert img.shape == mask.shape, "Mask shape does not match to image shape."
    
    assert img.shape == (64,704,704), "Wrong shape for image"
    assert mask.shape == (64,704,704), "Wrong shape for mask"

    
    if( ( np.count_nonzero(img)==0 )  or  ( np.count_nonzero(mask) == 0)  ):
        continue
        
    X_list.append(img)
    Y_list.append(mask)

In [136]:
X_train = X_list[1]
Y_train = Y_list[1]

X_test = X_list[2]
Y_test = Y_list[2]

X = np.asarray(X_list)
Y = np.asarray(Y_list)


X_train = np.stack((X_train,)*1, axis = 0)
y_train = np.expand_dims(Y_train, axis = 0)

X_test = np.stack((X_test,)*1, axis = 0)
y_test = np.expand_dims(Y_test, axis = 0)




# # X = X[1,2]
# # Y = Y[1,2]

# print(X.shape)
# print(Y.shape)

In [105]:
train_img = np.stack((X,)*1, axis = 1)
train_mask = np.expand_dims(Y, axis = 1)

train_img = torch.from_numpy(train_img)
train_mask = torch.from_numpy(train_mask)

# train_mask_cat = to_categorical(train_mask , num_classes = num_classes)

In [106]:
train_img.shape

torch.Size([9, 1, 64, 704, 704])

In [107]:
train_mask.shape

torch.Size([9, 1, 64, 704, 704])

In [108]:
from sklearn.model_selection import train_test_split


X_train, X_test, y_train, y_test = train_test_split(train_img, train_mask, test_size = 0.27, random_state =0 )

In [137]:
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

(1, 704, 704)
(1, 704, 704)
(1, 704, 704)
(1, 704, 704)


In [135]:
# train_dataset = [X_train, y_train]

In [111]:
# valid_dataset = [X_test, y_test]

In [125]:
train_loader = DataLoader(dataset = [(X_train, y_train)], batch_size = 1, shuffle = True, num_workers = 2)

In [126]:
valid_loader= DataLoader(dataset = [(X_test, y_test)], batch_size = 1, shuffle = False, num_workers = 2)

In [127]:
device = torch.device('cpu')
device

device(type='cpu')

In [128]:
model = build_unet()

In [129]:
model = model.to(device)

In [130]:
lr = 0.01

model = build_unet()

optimizer = torch.optim.Adam(model.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
loss_fn = DiceBCELoss()

best_valid_los = float('inf')

In [131]:
num_epochs = 20


for epoch in range(num_epochs):
        start_time = time.time()

        train_loss = train(model, train_loader, optimizer, loss_fn, device)
        valid_loss = evaluate(model, valid_loader, loss_fn, device)

        """ Saving the model """
        if valid_loss < best_valid_loss:
            data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
            print(data_str)

            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
        data_str += f'\tTrain Loss: {train_loss:.3f}\n'
        data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
        print(data_str)

torch.Size([1, 704, 704])


RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 704, 704]

In [32]:
dataset = [(X_train, y_train)]

In [33]:
type(dataset)

list

In [34]:
for i,j in dataset:
    print(i.shape, j.shape)

(6, 64, 704, 704) (6, 64, 704, 704)
