In [None]:
import os
import joblib
import cv2
import matplotlib.pyplot as plt
import zipfile
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image

In [None]:
# Folder holding the raw data; the archive is extracted into its "segmented" sub-folder.
RAW_PATH = "../../data/raw/"
# Folder where the pickled DataLoader objects are written and later read back.
PROCESSED_PATH = "../../data/processed/"

In [None]:
def dump(value, filename):
    """Serialize ``value`` to ``filename`` using joblib.

    Args:
        value: Object to persist; must not be ``None``.
        filename: Destination path; must not be ``None``.

    Raises:
        Exception: If either argument is ``None``.
    """
    # Guard clause: refuse to write when either argument is missing.
    if value is None or filename is None:
        raise Exception("Value or filename cannot be None".capitalize())

    joblib.dump(value=value, filename=filename)
    
def load(filename):
    """Read back an object previously stored with joblib.

    Args:
        filename: Path of the joblib file; must not be ``None``.

    Returns:
        Whatever ``joblib.load`` returns for ``filename``.

    Raises:
        Exception: If ``filename`` is ``None``.
    """
    # Guard clause first, happy path last.
    if filename is None:
        raise Exception("Filename cannot be None".capitalize())

    return joblib.load(filename)
    
def device_init(device = "mps"):
    """Resolve a device name to a ``torch.device``, falling back to the CPU.

    Args:
        device: ``"mps"``, ``"cuda"`` or anything else (treated as CPU).

    Returns:
        The requested accelerator when its backend reports availability,
        otherwise ``torch.device("cpu")``.
    """
    if device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")

    if device == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")

    # Unknown names and unavailable accelerators both end up on the CPU.
    return torch.device("cpu")

In [None]:
class Loader(Dataset):
    """Prepare paired image/mask tensors and persist them as pickled DataLoaders.

    The archive at ``image_path`` is extracted into ``RAW_PATH/segmented``. Every
    file name present in both extracted folders becomes one ``(image, mask)``
    sample, and three DataLoaders (full, train, test) are written to
    ``PROCESSED_PATH``.

    Args:
        image_path: Path of the zip archive holding the image and mask folders.
        image_size: Side length, in pixels, that images and masks are resized to.
        batch_size: Batch size of the full and train loaders; the test loader
            uses four times this value.
        split_ratio: Fraction of the samples assigned to the test split.
    """

    def __init__(self, image_path = None, image_size = 128, batch_size = 4, split_ratio = 0.30):
        self.image_path = image_path
        self.batch_size = batch_size
        self.split_ratio = split_ratio
        self.image_size = image_size
        self.channels = 1  # number of channels kept in every mask
        self.images = list()
        self.masks = list()

    def unzip_folder(self):
        """Extract the archive into ``RAW_PATH/segmented``.

        Raises:
            Exception: If ``RAW_PATH`` does not exist.
        """
        if os.path.exists(RAW_PATH):
            with zipfile.ZipFile(self.image_path, "r") as zip_ref:
                zip_ref.extractall(os.path.join(RAW_PATH, "segmented"))
        else:
            raise Exception("Raw data folder does not exist".capitalize())

    def base_transformation(self):
        """Return the image pipeline: resize, to tensor, scale each channel to [-1, 1]."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def mask_transformation(self):
        """Return the mask pipeline: resize, grayscale, to tensor in [0, 1].

        Masks are deliberately NOT shifted to [-1, 1]: the networks end in a
        sigmoid and the BCE/Dice/IoU criteria all treat the target as a
        probability, so a negative background value would corrupt the loss.
        """
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.Grayscale(num_output_channels=self.channels),
            transforms.ToTensor(),
        ])

    def split_dataset(self, **kwargs):
        """Split ``images`` and ``masks`` (keyword arguments) into train/test parts.

        Returns:
            ``[train_images, test_images, train_masks, test_masks]`` as produced
            by ``train_test_split`` with a fixed ``random_state`` of 42.
        """
        images = kwargs["images"]
        masks = kwargs["masks"]

        return train_test_split(images, masks, test_size=self.split_ratio, random_state=42)

    @staticmethod
    def _locate_folders():
        """Return ``(images_dir, masks_dir)`` found inside ``RAW_PATH/segmented``."""
        root = os.path.join(RAW_PATH, "segmented")

        # os.listdir order is arbitrary and may contain archive artefacts, so
        # keep real directories only and sort them for a reproducible choice.
        folders = sorted(
            name for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
            and not name.startswith((".", "__MACOSX"))
        )
        if len(folders) < 2:
            raise Exception(
                "Expected an image folder and a mask folder inside {}".format(root))

        # Prefer an unambiguous name hint; otherwise fall back to sorted order.
        mask_like = [name for name in folders if "mask" in name.lower()]
        if len(mask_like) == 1:
            masks = mask_like[0]
            images = next(name for name in folders if name != masks)
        else:
            images, masks = folders[0], folders[1]

        return os.path.join(root, images), os.path.join(root, masks)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return an RGB PIL image, or ``None`` if unreadable."""
        pixels = cv2.imread(path)
        if pixels is None:
            return None

        # OpenCV decodes to BGR while PIL/torchvision expect RGB.
        return Image.fromarray(cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB))

    def create_dataloader(self):
        """Build the tensors, create the three DataLoaders and pickle them.

        Returns:
            The DataLoader over the complete (unsplit) dataset.

        Raises:
            Exception: If ``PROCESSED_PATH`` does not exist or the extracted
                archive does not contain two folders.
            ValueError: If no readable image/mask pair is found.
        """
        # Fail fast, before any expensive image decoding.
        if not os.path.exists(PROCESSED_PATH):
            raise Exception("Processed data folder does not exist".capitalize())

        images_dir, masks_dir = self._locate_folders()

        # Start from a clean slate so re-running the cell does not duplicate samples.
        self.images = list()
        self.masks = list()

        image_transform = self.base_transformation()
        mask_transform = self.mask_transformation()
        mask_names = set(os.listdir(masks_dir))  # listed once, O(1) membership tests

        # Sorted iteration keeps the seeded train/test split reproducible.
        for name in sorted(os.listdir(images_dir)):
            if name not in mask_names:
                continue

            image = self._read_rgb(os.path.join(images_dir, name))
            mask = self._read_rgb(os.path.join(masks_dir, name))
            if image is None or mask is None:
                print("Skipping unreadable image/mask pair: {}".format(name))
                continue

            self.images.append(image_transform(image))
            self.masks.append(mask_transform(mask))

        if not self.images:
            raise ValueError(
                "No readable image/mask pairs found in {} and {}".format(images_dir, masks_dir))

        image_split = self.split_dataset(images=self.images, masks=self.masks)

        dataloader = DataLoader(
            dataset=list(zip(self.images, self.masks)),
            batch_size=self.batch_size, shuffle=True)

        train_dataloader = DataLoader(
            dataset=list(zip(image_split[0], image_split[2])),
            batch_size=self.batch_size, shuffle=True)

        test_dataloader = DataLoader(
            dataset=list(zip(image_split[1], image_split[3])),
            batch_size=self.batch_size*4,
            shuffle=True,
        )

        # Persisting is best effort: a failure is reported but the in-memory
        # loader is still returned to the caller.
        try:
            dump(
                value=dataloader, filename=os.path.join(PROCESSED_PATH, "dataloader.pkl"))

            dump(
                value=train_dataloader, filename=os.path.join(PROCESSED_PATH, "train_dataloader.pkl"))

            dump(
                value=test_dataloader, filename=os.path.join(PROCESSED_PATH, "test_dataloader.pkl"))

        except Exception as e:
            print(e)

        return dataloader

    @staticmethod
    def details_dataset():
        """Print the sample counts and the batch shapes of the pickled full loader.

        Raises:
            Exception: If ``PROCESSED_PATH`` does not exist.
        """
        if os.path.exists(PROCESSED_PATH):

            dataloader = load(os.path.join(PROCESSED_PATH, "dataloader.pkl"))
            images, masks = next(iter(dataloader))

            # Every sample is an (image, mask) pair, so one length covers both counts.
            total = len(dataloader.dataset)
            print("Total number of the images in the dataset is {}".format(total))
            print("Total number of the masks in the dataset is {}\n\n".format(total))
            print("The shape of the images is {}\nThe shape of the masks is {}".format(images.size(), masks.size()))

        else:
            raise Exception("Processed data folder does not exist".capitalize())

    @staticmethod
    def data_normalized(**kwargs):
        """Min-max scale ``data`` (keyword argument) to the range [0, 1].

        A constant input has no range to scale by; it is mapped to zeros
        instead of producing NaN through a division by zero.
        """
        data = kwargs["data"]
        lowest = data.min()
        span = data.max() - lowest
        if span == 0:
            span = 1

        return (data - lowest) / span

    @staticmethod
    def display_images():
        """Plot the image/mask pairs of one batch from the pickled test loader.

        Raises:
            Exception: If ``PROCESSED_PATH`` does not exist.
        """
        if os.path.exists(PROCESSED_PATH):
            dataloader = load(os.path.join(PROCESSED_PATH, "test_dataloader.pkl"))
            images, masks = next(iter(dataloader))

            columns = 8
            # Keep the 8x8 grid and only grow it when the batch needs more slots.
            rows = max(8, -(-2 * len(images) // columns))

            plt.figure(figsize=(30, 15))

            for index, image in enumerate(images):
                image = Loader.data_normalized(data=image.permute(1, 2, 0))
                mask = Loader.data_normalized(data=masks[index].permute(1, 2, 0))

                plt.subplot(rows, columns, 2 * index + 1)
                plt.imshow(image)
                plt.title("Image")
                plt.axis("off")

                plt.subplot(rows, columns, 2 * index + 2)
                # Drop a trailing singleton channel so imshow receives a 2-D array.
                plt.imshow(mask.squeeze(-1))
                plt.title("Mask")
                plt.axis("off")

            plt.tight_layout()
            plt.show()

        else:
            raise Exception("Processed data folder does not exist".capitalize())

In [None]:
# Archive location: set the BRAIN_ZIP_PATH environment variable to run this
# notebook on another machine; the original absolute path stays as the fallback.
ARCHIVE_PATH = os.environ.get(
    "BRAIN_ZIP_PATH", "/Users/shahmuhammadraditrahman/Desktop/brain.zip"
)

loader = Loader(image_path=ARCHIVE_PATH, batch_size=4, image_size=128, split_ratio=0.30)
loader.unzip_folder()
dataloader = loader.create_dataloader()

In [None]:
# Print the sample counts and batch shapes of the pickled full-dataset DataLoader.
Loader.details_dataset()

In [None]:
# Plot the image/mask pairs of one batch taken from the pickled test DataLoader.
Loader.display_images()

In [None]:
# train_data, train_label = next(iter(train_dataloader))
# test_data, test_label = next(iter(test_dataloader))


# train_data.size(), train_label.size(), test_data.size(), test_label.size()

#### UNet

In [None]:
from collections import OrderedDict

class Encoder(nn.Module):
    """Double-convolution block: Conv3x3 -> ReLU -> Conv3x3 -> BatchNorm -> ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved and only the channel count changes.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels=None, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.model = self.encoder_block()

    def encoder_block(self):
        """Assemble the named layers into an ``nn.Sequential``."""
        conv_options = dict(kernel_size=3, stride=1, padding=1)

        # Layer names are part of the checkpoint keys; keep them stable.
        stages = [
            ("conv1", nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                **conv_options)),
            ("relu1", nn.ReLU(inplace=True)),
            ("conv2", nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                **conv_options)),
            ("batch_norm1", nn.BatchNorm2d(self.out_channels)),
            ("relu2", nn.ReLU(inplace=True)),
        ]

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Apply the block to ``x``; a ``None`` input yields ``None``."""
        if x is None:
            return None

        return self.model(x)


if __name__ == "__main__":
    # Smoke test: the block must keep the spatial size and emit `out_channels`.
    # A batch of 2 checks the same thing as a batch of 64 without allocating
    # hundreds of megabytes every time the cell runs; no_grad skips the graph.
    encoder = Encoder(in_channels=3, out_channels=64)
    with torch.no_grad():
        assert encoder(torch.randn(2, 3, 128, 128)).shape == (2, 64, 128, 128)

In [None]:
class Decoder(nn.Module):
    """Up-sampling block: a 2x2 transposed convolution with stride 2.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels after up-sampling.
    """

    def __init__(self, in_channels=None, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.model = self.decoder_block()

    def decoder_block(self):
        """Wrap the transposed convolution in a named ``nn.Sequential``."""
        upsample = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=2,
            stride=2,
        )
        return nn.Sequential(OrderedDict([("deconv1", upsample)]))

    def forward(self, x=None, skip_info=None):
        """Up-sample ``x`` and, when given, append ``skip_info`` on the channel axis.

        Returns ``None`` when ``x`` is ``None``.
        """
        if x is None:
            return None

        upsampled = self.model(x)
        if skip_info is None:
            return upsampled

        return torch.cat((upsampled, skip_info), dim=1)

if __name__ == "__main__":
    # Smoke test: up-sampling doubles the spatial size and the skip connection
    # is appended on the channel axis (64 + 64 = 128 channels).
    # A batch of 2 is enough for a shape check; 64 samples at 256x256 would
    # allocate several gigabytes of activations.
    encoder = Encoder(in_channels=3, out_channels=64)
    decoder = Decoder(in_channels=64, out_channels=64)

    with torch.no_grad():
        skip_info = encoder(torch.randn(2, 3, 256, 256))
        noise_samples = torch.randn(2, 64, 128, 128)

        assert decoder(noise_samples, skip_info).shape == (2, 128, 256, 256)

In [None]:
import torch.nn as nn


class UNet(nn.Module):
    """U-Net built from the ``Encoder`` and ``Decoder`` blocks of this notebook.

    Four encoder levels (3 -> 64 -> 128 -> 256 -> 512 channels), each followed
    by 2x2 max pooling, lead to a 1024-channel bottom block. Four decoder
    levels up-sample, append the matching encoder output and refine the result.
    A 1x1 convolution with a sigmoid produces the single-channel output.
    """

    def __init__(self):
        super().__init__()

        # Contracting path.
        self.encoder_layer1 = Encoder(in_channels=3, out_channels=64)
        self.encoder_layer2 = Encoder(in_channels=64, out_channels=128)
        self.encoder_layer3 = Encoder(in_channels=128, out_channels=256)
        self.encoder_layer4 = Encoder(in_channels=256, out_channels=512)
        self.bottom_layer = Encoder(in_channels=512, out_channels=1024)

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Refinement blocks applied after each skip concatenation.
        self.intermediate_layer1 = Encoder(in_channels=1024, out_channels=512)
        self.intermediate_layer2 = Encoder(in_channels=512, out_channels=256)
        self.intermediate_layer3 = Encoder(in_channels=256, out_channels=128)
        self.intermediate_layer4 = Encoder(in_channels=128, out_channels=64)

        # Expanding path.
        self.decoder_layer1 = Decoder(in_channels=1024, out_channels=512)
        self.decoder_layer2 = Decoder(in_channels=512, out_channels=256)
        self.decoder_layer3 = Decoder(in_channels=256, out_channels=128)
        self.decoder_layer4 = Decoder(in_channels=128, out_channels=64)

        self.final_layer = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1), nn.Sigmoid()
        )

    def forward(self, x):
        """Run the encoder, the bottom block and the decoder; return the sigmoid map."""
        encoders = (
            self.encoder_layer1,
            self.encoder_layer2,
            self.encoder_layer3,
            self.encoder_layer4,
        )
        upsamplers = (
            self.decoder_layer1,
            self.decoder_layer2,
            self.decoder_layer3,
            self.decoder_layer4,
        )
        refiners = (
            self.intermediate_layer1,
            self.intermediate_layer2,
            self.intermediate_layer3,
            self.intermediate_layer4,
        )

        # Encoder path: remember each pre-pooling output for the skip connections.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)
            features = self.max_pool(features)

        features = self.bottom_layer(features)

        # Decoder path: the deepest skip connection is consumed first.
        for upsample, refine, skip in zip(upsamplers, refiners, reversed(skips)):
            features = refine(upsample(features, skip))

        return self.final_layer(features)

#### AttentionNet

In [None]:
class AttentionBlock(nn.Module):
    """Attention gate that re-weights skip features using the decoder signal.

    Both inputs are projected with a 1x1 convolution followed by batch norm,
    concatenated on the channel axis, passed through ReLU and reduced by
    ``psi`` to a single-channel sigmoid map that scales the projected skip
    features.

    Args:
        in_channels: Channel count of both ``x`` and ``skip_info``.
        out_channels: Channel count of the projections and of the output.
    """

    def __init__(self, in_channels=None, out_channels=None):
        super(AttentionBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.W_gate = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.BatchNorm2d(self.out_channels),
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.BatchNorm2d(self.out_channels),
        )

        self.psi = nn.Sequential(
            nn.Conv2d(
                # psi consumes the two concatenated projections, each of which
                # has `out_channels` channels (not `in_channels`).
                in_channels=self.out_channels * 2,
                out_channels=1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, skip_info):
        """Return the projected skip features scaled by the attention map.

        Args:
            x: Gating tensor with ``in_channels`` channels.
            skip_info: Skip-connection tensor with ``in_channels`` channels and
                the same batch and spatial size as ``x``.

        Raises:
            ValueError: If either input is ``None``.
        """
        if x is None or skip_info is None:
            raise ValueError(
                "AttentionBlock requires both the gating input and the skip connection")

        transformed_input = self.W_gate(x)
        transformed_skip = self.W_x(skip_info)

        merged_features = self.relu(torch.cat((transformed_input, transformed_skip), dim=1))
        attention_weights = self.psi(merged_features)

        # The single-channel map broadcasts over every skip channel.
        return transformed_skip * attention_weights
    
if __name__ == "__main__":
    # Smoke test: with equal input and output widths the gate keeps the shape
    # of the skip connection.
    out_result = torch.randn(64, 512, 16, 16)
    skip_info = torch.randn_like(out_result)

    attention = AttentionBlock(in_channels=512, out_channels=512)
    assert attention(out_result, skip_info).shape == out_result.shape

In [None]:
class AttentionUNet(nn.Module):
    """U-Net whose skip connections are filtered by ``AttentionBlock`` gates.

    The encoder and decoder layout matches ``UNet``. At every decoder level the
    up-sampled features gate the matching encoder output, and the up-sampled
    features are then concatenated with the gated skip features before the
    refinement block, as in the Attention U-Net design.

    NOTE: the concatenation previously joined the gated skip with the raw skip
    and dropped the up-sampled decoder features, so weights trained with that
    wiring should be retrained.
    """

    def __init__(self):
        super(AttentionUNet, self).__init__()

        self.encoder_layer1 = Encoder(in_channels=3, out_channels=64)
        self.encoder_layer2 = Encoder(in_channels=64, out_channels=128)
        self.encoder_layer3 = Encoder(in_channels=128, out_channels=256)
        self.encoder_layer4 = Encoder(in_channels=256, out_channels=512)
        self.bottom_layer = Encoder(in_channels=512, out_channels=1024)

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.intermediate_layer1 = Encoder(in_channels=1024, out_channels=512)
        self.intermediate_layer2 = Encoder(in_channels=512, out_channels=256)
        self.intermediate_layer3 = Encoder(in_channels=256, out_channels=128)
        self.intermediate_layer4 = Encoder(in_channels=128, out_channels=64)

        self.decoder_layer1 = Decoder(in_channels=1024, out_channels=512)
        self.decoder_layer2 = Decoder(in_channels=512, out_channels=256)
        self.decoder_layer3 = Decoder(in_channels=256, out_channels=128)
        self.decoder_layer4 = Decoder(in_channels=128, out_channels=64)

        self.attention_layer1 = AttentionBlock(in_channels=512, out_channels=512)
        self.attention_layer2 = AttentionBlock(in_channels=256, out_channels=256)
        self.attention_layer3 = AttentionBlock(in_channels=128, out_channels=128)
        self.attention_layer4 = AttentionBlock(in_channels=64, out_channels=64)

        self.final_layer = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1), nn.Sigmoid()
        )

    def forward(self, x):
        """Run the gated encoder/decoder and return the sigmoid map."""
        encoders = (
            self.encoder_layer1,
            self.encoder_layer2,
            self.encoder_layer3,
            self.encoder_layer4,
        )

        # Encoder path: keep every pre-pooling output as a skip connection.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)
            features = self.max_pool(features)

        features = self.bottom_layer(features)

        stages = zip(
            (self.decoder_layer1, self.decoder_layer2, self.decoder_layer3, self.decoder_layer4),
            (self.attention_layer1, self.attention_layer2, self.attention_layer3, self.attention_layer4),
            (self.intermediate_layer1, self.intermediate_layer2, self.intermediate_layer3, self.intermediate_layer4),
            reversed(skips),
        )

        # Decoder path: the deepest skip connection is consumed first.
        for upsample, attend, refine, skip in stages:
            upsampled = upsample(features)
            gated_skip = attend(upsampled, skip)

            # Up-sampled features first, skip second: the same order Decoder
            # uses for UNet. Both halves have the same width, so the channel
            # count expected by the refinement block is unchanged.
            features = refine(torch.cat((upsampled, gated_skip), dim=1))

        return self.final_layer(features)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * |P * A| + smooth) / (|P| + |A| + smooth)``.

    Args:
        smooth: Constant added to numerator and denominator so that two empty
            masks do not divide by zero.
    """

    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()

        self.smooth = smooth

    def forward(self, predicted, actual):
        """Return the Dice loss of the flattened tensors.

        Args:
            predicted: Predicted mask; presumably probabilities in [0, 1].
            actual: Target mask with the same number of elements.

        Returns:
            Scalar tensor that is 0 for perfect overlap and approaches 1 when
            the masks do not overlap.
        """
        predicted = predicted.contiguous().view(-1)
        actual = actual.contiguous().view(-1)

        intersection = (predicted * actual).sum()

        dice = (2.0 * intersection + self.smooth) / (predicted.sum() + actual.sum() + self.smooth)

        # A loss must shrink as the overlap grows, hence the complement of the
        # coefficient rather than the coefficient itself.
        return 1.0 - dice

In [None]:
class JaccardLoss(nn.Module):
    """Soft Jaccard (IoU) loss: ``1 - (|P * A| + smooth) / (|P| + |A| - |P * A| + smooth)``.

    Args:
        smooth: Constant added to numerator and denominator so that two empty
            masks do not divide by zero.
    """

    def __init__(self, smooth=1e-6):
        super(JaccardLoss, self).__init__()

        self.smooth = smooth

    def forward(self, predicted, actual):
        """Return the IoU loss of the flattened tensors.

        Args:
            predicted: Predicted mask; presumably probabilities in [0, 1].
            actual: Target mask with the same number of elements.

        Returns:
            Scalar tensor that is 0 for perfect overlap and approaches 1 when
            the masks do not overlap.
        """
        predicted = predicted.contiguous().view(-1)
        actual = actual.contiguous().view(-1)

        intersection = (predicted * actual).sum()
        union = (predicted + actual).sum() - intersection

        iou = (intersection + self.smooth) / (union + self.smooth)

        # The complement applies to the whole ratio, not to the numerator only.
        return 1.0 - iou

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss: ``alpha * (1 - p_t) ** gamma * BCE``, averaged over elements.

    Args:
        gamma: Focusing exponent; larger values down-weight easy elements more.
        alpha: Constant scaling factor applied to every element.
    """

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()

        self.gamma = gamma
        self.alpha = alpha

    def forward(self, predicted, actual):
        """Return the mean focal loss of the flattened tensors.

        Args:
            predicted: Probabilities in [0, 1] (required by ``nn.BCELoss``).
            actual: Targets with the same number of elements.
        """
        predicted = predicted.contiguous().view(-1)
        actual = actual.contiguous().view(-1)

        # Keep the loss per element: weighting the batch mean would merely
        # rescale plain BCE and never focus on the hard elements.
        criterion = nn.BCELoss(reduction="none")
        BCE = criterion(predicted, actual)
        pt = torch.exp(-BCE)  # probability assigned to the true class

        return (self.alpha * (1 - pt) ** self.gamma * BCE).mean()

In [None]:
class ComboLoss(nn.Module):
    """Sum of Dice loss, focal loss and mean binary cross-entropy.

    Args:
        smooth: Smoothing constant of the Dice term.
        alpha: Constant scaling factor of the focal term.
        gamma: Focusing exponent of the focal term.
    """

    def __init__(self, smooth = 1e-6, alpha=0.25, gamma=2):
        super(ComboLoss, self).__init__()

        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predicted, actual):
        """Return ``dice_loss + focal_loss + BCE`` for the flattened tensors.

        Args:
            predicted: Probabilities in [0, 1] (required by ``nn.BCELoss``).
            actual: Targets with the same number of elements.
        """
        predicted = predicted.contiguous().view(-1)
        actual = actual.contiguous().view(-1)

        # Per-element BCE feeds both the focal term and the plain BCE term.
        criterion = nn.BCELoss(reduction="none")
        element_bce = criterion(predicted, actual)
        BCE = element_bce.mean()

        # The focal weight is applied per element and averaged afterwards, so
        # hard elements actually dominate the term.
        pt = torch.exp(-element_bce)
        focal_loss = (self.alpha * (1 - pt) ** self.gamma * element_bce).mean()

        intersection = (predicted * actual).sum()

        dice_loss = 1 - (2.0 * intersection + self.smooth) / (predicted.sum() + actual.sum() + self.smooth)

        return (dice_loss + focal_loss) + BCE

In [None]:
def helpers(**kwargs):
    """Build the model, optimizer, criterion and DataLoaders for a training run.

    Keyword Args:
        is_attentionUNet: Truthy to build ``AttentionUNet``, otherwise ``UNet``.
        device: Device the model is moved to.
        lr: Learning rate of the Adam optimizer.
        loss: One of ``"dice"``, ``"dice_bce"``, ``"IoU"``, ``"focal"``,
            ``"combo"``; ``None`` or ``"bce"`` selects ``nn.BCELoss``.
        smooth: Smoothing constant for Dice/IoU/combo (default ``1e-6``).
        alpha: Focal scaling factor for focal/combo (default ``0.25``).
        gamma: Focal exponent for focal/combo (default ``2``).
        beta1, beta2: Adam betas (defaults ``0.5`` and ``0.999``).

    Returns:
        Dict with the keys ``model``, ``optimizer``, ``criterion``,
        ``train_dataloader`` and ``test_dataloader``.

    Raises:
        Exception: If ``PROCESSED_PATH`` does not exist.
    """
    import warnings  # local import: only needed for the fallback notice below

    is_attentionUNet = kwargs["is_attentionUNet"]
    device = kwargs["device"]
    loss = kwargs["loss"]

    # Loss hyper-parameters are optional; the defaults mirror the loss classes.
    smooth = kwargs.get("smooth", 1e-6)
    alpha = kwargs.get("alpha", 0.25)
    gamma = kwargs.get("gamma", 2)

    network_class = AttentionUNet if is_attentionUNet else UNet
    model = network_class().to(device)

    # The defaults keep the betas this function has always used.
    betas = (kwargs.get("beta1", 0.5), kwargs.get("beta2", 0.999))
    optimizer = optim.Adam(model.parameters(), lr=kwargs["lr"], betas=betas)

    if os.path.exists(PROCESSED_PATH):
        train_dataloader = load(filename=os.path.join(PROCESSED_PATH, "train_dataloader.pkl"))
        test_dataloader = load(filename=os.path.join(PROCESSED_PATH, "test_dataloader.pkl"))
    else:
        raise Exception("Dataloader - train & test cannot be loaded".capitalize())

    # Lazily constructed so only the selected criterion is instantiated.
    criterion_builders = {
        "dice": lambda: DiceLoss(smooth=smooth),
        # NOTE(review): "dice_bce" maps to the plain Dice loss; no BCE term is
        # added here - confirm whether a combined criterion was intended.
        "dice_bce": lambda: DiceLoss(smooth=smooth),
        "IoU": lambda: JaccardLoss(smooth=smooth),
        "focal": lambda: FocalLoss(gamma=gamma, alpha=alpha),
        "combo": lambda: ComboLoss(smooth=smooth, alpha=alpha, gamma=gamma),
    }

    builder = criterion_builders.get(loss) if isinstance(loss, str) else None
    if builder is not None:
        criterion = builder()
    else:
        # Keep the historical BCE fallback, but make a misspelt or
        # unimplemented loss name visible instead of silently swapping it.
        if loss not in (None, "bce"):
            warnings.warn(
                "Unknown loss {!r}; falling back to nn.BCELoss".format(loss))
        criterion = nn.BCELoss()

    return {"model": model,
            "optimizer": optimizer,
            "criterion": criterion,
            "train_dataloader": train_dataloader,
            "test_dataloader": test_dataloader
            }

In [None]:
import numpy as np

class Trainer:
    """Train a ``UNet`` or ``AttentionUNet`` and report per-epoch losses.

    Args:
        epochs: Number of passes over the training DataLoader.
        lr: Learning rate forwarded to ``helpers``.
        loss: Loss name forwarded to ``helpers``.
        is_attentionUNet: Build ``AttentionUNet`` instead of ``UNet``.
        smooth, alpha, gamma: Loss hyper-parameters forwarded to ``helpers``.
        beta1, beta2: Adam betas forwarded to ``helpers``.
        device: Device name resolved by ``device_init``.
        display: When true, print the losses after every epoch.
    """

    def __init__(
        self, epochs = 50,
        lr = 1e-2,
        loss = None,
        is_attentionUNet = False,
        smooth = 0.01,
        alpha = 0.25,
        gamma = 2,
        beta1 = 0.9,
        beta2 = 0.999,
        device = "mps",
        display = True
        ):

        self.epochs = epochs
        self.lr = lr
        self.loss = loss
        self.is_attentionUNet = is_attentionUNet
        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma
        self.beta1 = beta1
        self.beta2 = beta2
        self.device = device_init(device)
        self.display = display

        self.setup = helpers(
                is_attentionUNet = self.is_attentionUNet,
                device = self.device,
                lr = self.lr,
                loss = self.loss,
                smooth = self.smooth,
                alpha = self.alpha,
                gamma = self.gamma,
                # Forwarded so the betas accepted above are not silently ignored.
                beta1 = self.beta1,
                beta2 = self.beta2)

        self.model = self.setup["model"]
        self.optimizer = self.setup["optimizer"]
        self.criterion = self.setup["criterion"]
        self.train_dataloader = self.setup["train_dataloader"]
        self.test_dataloader = self.setup["test_dataloader"]

    def _predict(self, images):
        """Forward ``images`` in eval mode without autograd, then restore the mode."""
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(images)
        finally:
            self.model.train(was_training)

    def update_train(self, **kwargs):
        """Run one optimisation step on ``images``/``masks``; return the loss as a float."""
        self.model.train()
        self.optimizer.zero_grad()

        train_predicted_masks = self.model(kwargs["images"])

        train_predicted_loss = self.criterion(train_predicted_masks, kwargs["masks"])

        train_predicted_loss.backward()

        self.optimizer.step()

        return train_predicted_loss.item()

    def update_test(self, **kwargs):
        """Evaluate ``images``/``masks`` without touching the model; return the loss as a float.

        Runs in eval mode and without autograd, so no graph is kept and the
        BatchNorm running statistics are not updated by evaluation data.
        """
        test_predicted_masks = self._predict(kwargs["images"])

        test_predicted_loss = self.criterion(test_predicted_masks, kwargs["masks"])

        return test_predicted_loss.item()

    def train(self):
        """Train for ``self.epochs`` epochs, save a preview per epoch and the final weights."""
        preview_dir = "../../outputs/train_images"
        checkpoint_path = "../../checkpoints/best_model/model.pth"

        # Create the output folders up front so a long run cannot fail on saving.
        os.makedirs(preview_dir, exist_ok=True)
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        for epoch in tqdm(range(self.epochs)):
            total_train_loss = list()
            total_test_loss = list()

            for images, masks in self.train_dataloader:
                images = images.to(self.device)
                masks = masks.to(self.device)

                total_train_loss.append(self.update_train(images=images, masks=masks))

            for images, masks in self.test_dataloader:
                images = images.to(self.device)
                masks = masks.to(self.device)

                total_test_loss.append(self.update_test(images=images, masks=masks))

            if self.display:
                print("Epoch - [{}/{}] - train_loss: {:.5f} - test_loss: {:.5f}".format(
                    epoch+1, self.epochs, np.mean(total_train_loss), np.mean(total_test_loss)))

            # Preview the predictions for one training batch.
            image, _ = next(iter(self.train_dataloader))
            pred_masks = self._predict(image.to(self.device))
            save_image(
                pred_masks,
                os.path.join(preview_dir, "image_{}.png".format(epoch + 1)),
                normalize=True,
            )

        # The weights of the final epoch are what gets stored.
        torch.save(self.model.state_dict(), checkpoint_path)

In [None]:
# NOTE(review): helpers() has no "tversky" branch, so this run falls through to its
# default criterion, nn.BCELoss - confirm which loss was intended.
trainer = Trainer(epochs=15, lr=1e-4, loss="tversky", smooth=0.001, device="mps", display=True)

# Runs the training loop; Trainer.train() writes the weights to
# ../../checkpoints/best_model/model.pth when it finishes.
trainer.train()

In [None]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = UNet().to(device)
# Inference only from here on: use the BatchNorm running statistics.
model.eval()
# map_location lets a checkpoint saved on one backend load on another (e.g. mps -> cpu).
model.load_state_dict(torch.load("../../checkpoints/best_model/model.pth", map_location=device))

In [None]:
class Test:
    """Plot predicted masks next to the ground truth for one test batch.

    Args:
        device: Device name resolved by ``device_init``.
    """

    def __init__(self, device = "mps"):
        self.device = device_init(device)

    @staticmethod
    def _to_display(tensor):
        """Convert a CHW tensor to an HWC (or HW) array scaled to [0, 1]."""
        array = tensor.permute(1, 2, 0).cpu().detach().numpy()

        lowest = array.min()
        span = array.max() - lowest
        # A constant map has no range; show zeros instead of dividing by zero.
        array = (array - lowest) / (span if span != 0 else 1)

        # Drop a trailing singleton channel so imshow receives a 2-D array.
        return array[..., 0] if array.shape[-1] == 1 else array

    def plot(self, model = None):
        """Show predictions and ground-truth masks for the first test batch.

        Args:
            model: Network used for the prediction. When omitted, the
                notebook-level variable ``model`` is used, as before.

        Raises:
            Exception: If ``PROCESSED_PATH`` does not exist or no model is available.
        """
        if os.path.exists(PROCESSED_PATH):
            dataloader = load(filename=os.path.join(PROCESSED_PATH, "test_dataloader.pkl"))
        else:
            raise Exception("Dataloader - test cannot be loaded".capitalize())

        network = model if model is not None else globals().get("model")
        if network is None:
            raise Exception("No model available: pass one to plot() or define `model`")

        images, masks = next(iter(dataloader))

        # Predict in eval mode without autograd, then restore the previous mode.
        was_training = network.training
        network.eval()
        try:
            with torch.no_grad():
                pred_masks = network(images.to(self.device))
        finally:
            network.train(was_training)

        columns = 8
        # Keep the 8x8 grid and only grow it when the batch needs more slots.
        rows = max(8, -(-2 * len(pred_masks) // columns))

        plt.figure(figsize=(30, 15))

        for index, prediction in enumerate(pred_masks):
            plt.subplot(rows, columns, 2 * index + 1)
            plt.imshow(Test._to_display(prediction), cmap="gray")
            plt.title("Predicted Mask")
            plt.axis("off")

            plt.subplot(rows, columns, 2 * index + 2)
            plt.imshow(Test._to_display(masks[index]), cmap="gray")
            plt.title("Ground Truth")
            plt.axis("off")

        plt.show()

In [None]:
# Plot the predictions of the notebook-level `model` next to the ground-truth
# masks of one batch from the pickled test DataLoader.
test = Test(device="mps")
test.plot()

#### AttentionUNet

In [None]:
# NOTE(review): helpers() has no "tversky" branch, so this run also falls through to
# its default criterion, nn.BCELoss - confirm which loss was intended.
trainer = Trainer(
    epochs=15, lr=1e-4, loss="tversky", is_attentionUNet=True, smooth=0.001, device="mps", display=True
)

# Trainer.train() saves to ../../checkpoints/best_model/model.pth, the same path
# the UNet run above wrote to, so the UNet weights are overwritten here.
trainer.train()

In [None]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = AttentionUNet().to(device)

# map_location lets a checkpoint saved on one backend load on another (e.g. mps -> cpu).
model.load_state_dict(torch.load("../../checkpoints/best_model/model.pth", map_location=device))
# Inference only from here on: use the BatchNorm running statistics.
model.eval()
test.plot()