In [4]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms.functional as TF
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import random
import os
import matplotlib.pylab as plt
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

In [18]:
# Runtime configuration: device, data locations, and training hyperparameters.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset root. Every derived path below (and later checkpoint/image paths) is
# built by string concatenation, so each constant keeps its trailing separator.
# NOTE(review): the '\\' separators are Windows-specific.
PATH = 'deep-learning-final-project-lung-cancer-tumor-segmentation\\data\\processed\\'
path = PATH  # lowercase alias kept for any cell that still refers to `path`
TRAIN_IMG_PATH = PATH + 'train\\original_png\\'
TRAIN_MASK_PATH = PATH + 'train\\mask_png\\'
TEST_IMG_PATH = PATH + 'test\\original_png\\'
TEST_MASK_PATH = PATH + 'test\\mask_png\\'

IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10
BATCH_SIZE = 16
NUM_WORKERS = 2
PIN_MEMORY = True
LOAD_MODEL = True  # resume from PATH + 'model.pt' before training

# Seed every RNG the pipeline may touch: torch drives DataLoader shuffling and
# weight init; augmentations presumably draw from Python's `random` / NumPy.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [6]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use stride 1 and padding 1, so height and width are
    preserved; only the channel count changes.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # bias=False: each conv feeds straight into BatchNorm, whose learned
        # shift makes a separate conv bias redundant.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNET(nn.Module):
    """U-Net encoder/decoder producing raw per-pixel logits.

    Args:
        in_channels: channels of the input image (1 for grayscale slices).
        out_channels: channels of the output logit map.
        features: encoder stage widths, shallow to deep. The bottleneck uses
            twice the last width; the decoder mirrors the encoder.

    The output has the input's height and width for any input size. Odd
    sizes are handled by resizing the upsampled map to its skip connection.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per stage, each followed by 2x2 pooling in forward().
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: alternating (upsample, DoubleConv) pairs, so self.ups has
        # 2 * len(features) entries; the DoubleConv input is doubled by the
        # concatenated skip connection.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # 1x1 conv maps the shallowest decoder features to output logits.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest skip is consumed first by the decoder.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Max-pooling floors odd sizes, so the upsampled map can be a pixel
            # smaller than its skip. Match only the spatial dims (H, W); passing
            # the full (N, C, H, W) shape as a resize target is invalid.
            if x.shape[-2:] != skip_connection.shape[-2:]:
                x = nn.functional.interpolate(
                    x,
                    size=skip_connection.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)
        

In [7]:
class CTDataset(Dataset):
    """Paired CT-slice / mask PNG dataset.

    Images and masks are matched by file name: ``image_path/<name>.png`` pairs
    with ``mask_path/<name>.png``. Only ``.png`` files in ``image_path`` are
    indexed, so stray files (e.g. ``Thumbs.db``) do not inflate ``len()``.
    Numeric stems sort numerically (0, 1, 2, ..., 10), which preserves the
    ``index -> "<index>.png"`` mapping when files are named 0..N-1.

    Args:
        image_path: directory containing the input slices.
        mask_path: directory containing the masks, using the same file names.
        transform: optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, image_path, mask_path, transform=None):
        self.image_path = image_path
        self.mask_path = mask_path
        self.transform = transform
        png_names = [name for name in os.listdir(image_path) if name.lower().endswith(".png")]
        self.image_names = sorted(png_names, key=self._sort_key)

    @staticmethod
    def _sort_key(name):
        # Numeric stems first, in numeric order; any other names after, alphabetically.
        stem = os.path.splitext(name)[0]
        return (0, int(stem), name) if stem.isdigit() else (1, 0, name)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        img_name = self.image_names[index]
        img_path = os.path.join(self.image_path, img_name)
        # The mask has exactly the same file name, in the mask directory.
        mask_path = os.path.join(self.mask_path, img_name)

        # Context managers close the PNG file handles instead of leaving them to GC.
        with Image.open(img_path) as img:
            image = np.array(img.convert("L"), dtype=np.float32)
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert("L"), dtype=np.float32)
        # Binarize: any non-zero pixel is foreground. This covers both 0/255
        # and 0/1 encodings and keeps BCE targets strictly in {0, 1}.
        mask = (mask > 0).astype(np.float32)

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [8]:
# Albumentations pipelines. Both resize to the network input size and end
# with ToTensorV2; only the training pipeline applies random augmentation.
# NOTE(review): no A.Normalize step is applied, so images reach the model as
# raw float32 grayscale values (0-255, since the dataset converts to mode
# "L"). Adding normalization would change the input scale relative to any
# checkpoint trained with this pipeline.
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=35, p=1.0),   # always rotate, angle drawn within +/-35 degrees
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    ToTensorV2(),
])

In [14]:
def save_checkpoint(state, path=PATH, filename="model.pt"):
    """Persist a checkpoint object with ``torch.save``.

    Args:
        state: object to serialize (the training loop passes a dict of
            model and optimizer state dicts).
        path: directory prefix; concatenated directly with ``filename``, so
            it must end with a path separator.
        filename: checkpoint file name inside ``path``.
    """
    destination = path + filename
    print("=> Saving checkpoint")
    torch.save(state, destination)

def load_checkpoint(checkpoint, model, optimizer=None):
    """Restore model weights, and optionally optimizer state, from a checkpoint.

    Args:
        checkpoint: dict with a ``"state_dict"`` entry and, optionally, an
            ``"optimizer"`` entry, matching what the training loop saves.
        model: module whose weights are replaced in place.
        optimizer: when given and the checkpoint has an ``"optimizer"`` entry,
            its state (e.g. Adam moments) is restored too, so training can
            resume rather than restart the optimizer from scratch.
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])

def get_loaders(train_path, train_mask_path, val_path, val_mask_path, batch_size, train_transform, val_transform, num_workers=4, pin_memory=True):
    """Build the (train, validation) DataLoader pair over CTDataset.

    The training loader shuffles; the validation loader keeps directory order.
    ``batch_size`` and ``pin_memory`` are forwarded to both loaders.

    NOTE(review): ``num_workers`` is accepted but NOT forwarded to either
    DataLoader, so loading runs in the main process. Wiring it through may
    break on Windows/Jupyter, where worker processes cannot import classes
    defined in the notebook. Verify before enabling.

    Returns:
        (train_loader, val_loader)
    """
    def _loader(images, masks, transform, shuffle):
        dataset = CTDataset(image_path=images, mask_path=masks, transform=transform)
        return DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, shuffle=shuffle)

    return (
        _loader(train_path, train_mask_path, train_transform, True),
        _loader(val_path, val_mask_path, val_transform, False),
    )

def check_accuracy(loader, model, device="cuda"):
    """Print (and return) pixel accuracy and mean per-batch Dice.

    Predictions are ``sigmoid(model(x)) > 0.5``. Dice is computed per batch
    over all of that batch's pixels and then averaged across batches.

    Args:
        loader: iterable of (images, masks) batches; masks get a channel axis
            inserted at dim 1 so they line up with the model output.
        model: segmentation network returning logits.
        device: device the batches are moved to.

    Returns:
        (accuracy_percent, mean_dice) as Python floats. Both are 0.0 when the
        loader yields no batches. The model's prior train/eval mode is restored.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in tqdm(loader):
                x = x.to(device)
                y = y.to(device).float().unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                # Accumulate Python numbers so no device tensors pile up.
                num_correct += int((preds == y).sum().item())
                num_pixels += preds.numel()
                # Epsilon avoids 0/0 when both prediction and target are
                # empty; such a batch contributes a Dice of 0.
                dice_total += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
                num_batches += 1
    finally:
        model.train(was_training)

    accuracy = num_correct / num_pixels * 100 if num_pixels else 0.0
    mean_dice = dice_total / num_batches if num_batches else 0.0
    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {mean_dice}")
    return accuracy, mean_dice

def save_predicts_as_imgs(loader, model, folder='saved_images\\', device="cuda"):
    """Save thresholded predictions and ground-truth masks as PNG files.

    For batch ``idx``, writes ``PATH + folder + "pred_<idx>.png"`` (the
    ``sigmoid(logits) > 0.5`` predictions) and ``PATH + folder + "<idx>.png"``
    (the target masks with a channel axis added). ``torchvision.utils.save_image``
    tiles a whole batch into a single grid image.

    Args:
        loader: iterable of (images, masks) batches.
        model: segmentation network returning logits.
        folder: sub-folder under PATH; concatenated directly, so it must end
            with a path separator.
        device: device the images are moved to for inference.

    The model's prior train/eval mode is restored afterwards.
    """
    out_dir = PATH + folder
    # save_image does not create directories, so make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(x)) > 0.5).float()
            torchvision.utils.save_image(preds, f"{out_dir}pred_{idx}.png")
            torchvision.utils.save_image(y.unsqueeze(1), f"{out_dir}{idx}.png")
    finally:
        model.train(was_training)

In [10]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one full-precision training epoch, showing the latest loss in tqdm.

    Args:
        loader: iterable of (images, masks) batches.
        model: network returning logits.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: loss taking (logits, float targets).
        scaler: accepted for mixed-precision training but currently unused;
            every step runs in full precision.
    """
    progress = tqdm(loader)
    for inputs, masks in progress:
        inputs = inputs.to(device=DEVICE)
        # Add a channel axis at dim 1 so targets line up with the logits
        # (assumes masks arrive as (N, H, W) — TODO confirm).
        masks = masks.float().unsqueeze(1).to(device=DEVICE)

        loss = loss_fn(model(inputs), masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress.set_postfix(loss=loss.item())
    

In [21]:
model = UNET(in_channels=1, out_channels=1).to(device=DEVICE)
loss_fn = nn.BCEWithLogitsLoss()  # model outputs raw logits
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Use the config-cell loader settings; pinned memory only helps on CUDA.
train_loader, val_loader = get_loaders(
    TRAIN_IMG_PATH,
    TRAIN_MASK_PATH,
    TEST_IMG_PATH,
    TEST_MASK_PATH,
    BATCH_SIZE,
    train_transform,
    val_transform,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY and DEVICE.type == "cuda",
)

checkpoint_file = PATH + 'model.pt'
if LOAD_MODEL:
    if os.path.exists(checkpoint_file):
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        load_checkpoint(torch.load(checkpoint_file, map_location=DEVICE), model)
    else:
        print(f"=> No checkpoint at {checkpoint_file}; training from scratch")

check_accuracy(val_loader, model, device=DEVICE)
# Passed through to train_fn; disabled off-CUDA to avoid the GradScaler warning.
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE.type == "cuda")

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model (weights + optimizer state) after every epoch
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    check_accuracy(val_loader, model, device=DEVICE)

    # print some examples to a folder
    save_predicts_as_imgs(val_loader, model, folder="saved_images\\", device=DEVICE)

=> Loading checkpoint


100%|██████████| 17/17 [01:41<00:00,  5.99s/it]


Got 17285691/17301504 with acc 99.91
Dice score: 0.0


100%|██████████| 45/45 [12:41<00:00, 16.93s/it, loss=0.0971]


=> Saving checkpoint


100%|██████████| 17/17 [01:36<00:00,  5.66s/it]


Got 17286146/17301504 with acc 99.91
Dice score: 0.0


100%|██████████| 45/45 [11:22<00:00, 15.18s/it, loss=0.0876]


=> Saving checkpoint


100%|██████████| 17/17 [01:24<00:00,  4.96s/it]


Got 17286146/17301504 with acc 99.91
Dice score: 0.0


 20%|██        | 9/45 [02:31<10:05, 16.83s/it, loss=0.0881]


KeyboardInterrupt: 

In [None]:
# Inspect one training mask. `train_image_path` was never defined (NameError on
# a fresh kernel); the mask directory constant from the config cell is used.
mask_one = Image.open(TRAIN_MASK_PATH + "0.png")
mask_one

In [None]:
# PILToTensor converts the PIL image to a tensor without rescaling values.
transform = transforms.Compose([transforms.PILToTensor()])

mask_one_tensor = transform(mask_one)
# Map the 255 foreground value to 1 so the mask becomes {0, 1}.
mask_one_tensor.masked_fill_(mask_one_tensor == 255, 1)
(mask_one_tensor.min(), mask_one_tensor.max(), mask_one_tensor.shape)

In [None]:
# Smoke test: with 1 input / 1 output channel, UNET output matches the input shape.
# Uses its own names so re-running this cell never overwrites the trained `model`.
smoke_x = torch.randn((3, 1, 160, 160))
smoke_model = UNET(in_channels=1, out_channels=1)
with torch.no_grad():
    smoke_preds = smoke_model(smoke_x)
print(smoke_preds.shape)
print(smoke_x.shape)
assert smoke_preds.shape == smoke_x.shape

In [None]:
# Scratch check: count the positive entries of a small array.
a = [1] * 3
b = np.array([0, 1, 0])
np.sum(b > 0)