In [1]:
import argparse
import copy
import gc
import os
import random
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import optim, nn
from torch.utils.data import DataLoader

from data import transforms
from data.mri_dataset import DataTransform
from data.mri_dataset import SliceData
from models.subsampling import SubsamplingLayer
from models.vanilla import VanillaModel


In [2]:
# Prefer the GPU whenever CUDA reports itself usable; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [3]:
class Trainer:
    """Train, validate and test a model, and persist metrics and figures.

    The trainer runs the optimisation loop with early stopping (patience of
    5 epochs) and a ReduceLROnPlateau schedule driven by the validation loss,
    tracks PSNR per batch, and writes loss curves, example images and a PSNR
    summary below ``results_root``.

    Args:
        model: Module to train. It must expose ``model.subsample.learn_mask``
            and, when that flag is truthy, ``model.subsample.mask_grad(lr)``.
        optimizer: Optimizer already bound to the model's parameters.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        device: Device handed to ``.to()`` for the model and for every batch.
        mask_lr: Value forwarded to ``model.subsample.mask_grad``.
        results_root: Directory under which ``graphs/``, ``images/`` and
            ``psnr/`` are created on demand.
        drop_rate: Used only to label the output file names.
        learn_mask: Used only to pick the "learned_mask"/"unlearned_mask"
            label in output file names; the training loop itself consults
            ``model.subsample.learn_mask``.
    """

    def __init__(self, model, optimizer, loss_fn, device, mask_lr, results_root, drop_rate, learn_mask=False):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device
        self.mask_lr = mask_lr
        self.results_root = results_root
        self.early_stopping_patience = 5
        self.best_val_loss = float('inf')
        self.no_improve_epochs = 0
        self.best_model_state = None
        self.learn_mask = learn_mask
        self.drop_rate = drop_rate
        self.train_psnr_mean = 0
        self.train_psnr_std = 0
        self.test_psnr_mean = 0
        self.test_psnr_std = 0
        self.train_losses = []
        self.val_losses = []
        # `verbose` is deliberately not passed: it is deprecated in recent
        # PyTorch releases. Learning-rate reductions are reported by
        # _step_scheduler() instead.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.1, patience=2
        )

    def fit(self, train_loader, val_loader, epochs, i=0):
        """Run up to ``epochs`` epochs with early stopping, then restore the best weights.

        Args:
            train_loader: Iterable of ``(inputs, targets)`` training batches.
            val_loader: Iterable of ``(inputs, targets)`` validation batches.
            epochs: Maximum number of epochs; ``0`` skips training entirely.
            i: Suffix used in the file name of the saved loss graph.
        """
        # Start from the stored values so that epochs == 0 leaves them unchanged
        # instead of failing on an unbound local below.
        train_psnr, train_psnr_std = self.train_psnr_mean, self.train_psnr_std

        for epoch in range(epochs):
            train_loss, train_psnr, train_psnr_std = self.train_epoch(train_loader)
            val_loss, val_psnr, val_psnr_std = self.evaluate(val_loader)
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            print(f'Epopch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}, Train PSNR: {train_psnr}, Val PSNR: {val_psnr} ')

            self._step_scheduler(val_loss)

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                # state_dict() hands back references to the live parameter
                # tensors; without a deep copy the snapshot would silently
                # follow every later optimizer step.
                self.best_model_state = copy.deepcopy(self.model.state_dict())
                self.no_improve_epochs = 0
            else:
                self.no_improve_epochs += 1
                if self.no_improve_epochs >= self.early_stopping_patience:
                    print('Early stopping triggered.')
                    break

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        # Note: these are the statistics of the last epoch that ran, not
        # necessarily of the epoch whose weights were restored above.
        self.train_psnr_mean, self.train_psnr_std = train_psnr, train_psnr_std
        self.plot_losses(i)

    def _step_scheduler(self, val_loss):
        """Advance the plateau scheduler and print any learning-rate reduction."""
        lrs_before = [group['lr'] for group in self.optimizer.param_groups]
        self.scheduler.step(val_loss)
        for idx, (old_lr, group) in enumerate(zip(lrs_before, self.optimizer.param_groups)):
            new_lr = group['lr']
            if new_lr != old_lr:
                print(f'Reducing learning rate of group {idx} to {new_lr:.4e}.')

    def train_epoch(self, loader):
        """Train for one pass over ``loader``.

        Returns:
            Tuple ``(mean batch loss, mean batch PSNR, std of batch PSNR)``.
        """
        self.model.train()
        total_loss = 0
        psnr_values = []

        for inputs, targets in loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            self.optimizer.step()

            # The mask update is delegated to the subsampling layer and runs
            # after the regular optimizer step, as in the original loop.
            if self.model.subsample.learn_mask:
                self.model.subsample.mask_grad(self.mask_lr)

            total_loss += loss.item()
            # Detach so the metric does not extend the autograd graph.
            psnr_values.append(self.calculate_psnr(outputs.detach(), targets))

            # Drop batch tensors eagerly and release cached GPU blocks.
            del inputs, outputs, targets, loss
            torch.cuda.empty_cache()

        avg_loss = total_loss / len(loader)
        avg_psnr = np.mean(psnr_values)
        psnr_std = np.std(psnr_values)
        return avg_loss, avg_psnr, psnr_std

    def evaluate(self, loader):
        """Evaluate the model on ``loader`` without gradient tracking.

        Returns:
            Tuple ``(mean batch loss, mean batch PSNR, std of batch PSNR)``.
        """
        self.model.eval()
        total_loss = 0
        psnr_values = []
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.loss_fn(outputs, targets)
                total_loss += loss.item()
                psnr_values.append(self.calculate_psnr(outputs, targets))

        avg_loss = total_loss / len(loader)
        avg_psnr = np.mean(psnr_values)
        psnr_std = np.std(psnr_values)
        return avg_loss, avg_psnr, psnr_std

    def calculate_psnr(self, output, target):
        """Return the PSNR in dB of ``output`` against ``target`` as a Python float.

        The data range is ``max(target) - min(target)`` taken over the whole
        tensor passed in; a zero MSE yields ``inf``.
        """
        mse = torch.mean((output - target) ** 2)
        if mse == 0:
            return float('inf')
        max_pixel_value = torch.max(target) - torch.min(target)
        psnr = 20 * torch.log10(max_pixel_value / torch.sqrt(mse))
        return psnr.item()

    def save_psnr_results(self):
        """Write the stored train/test PSNR mean and std to a text file under ``psnr/``."""
        os.makedirs(f'{self.results_root}/psnr', exist_ok=True)
        mask_status = "learned_mask" if self.learn_mask else "unlearned_mask"
        with open(f'{self.results_root}/psnr/{mask_status}_{self.drop_rate}.txt', 'w') as f:
            f.write(f'Train PSNR mean: {self.train_psnr_mean}\n')
            f.write(f'Train PSNR std: {self.train_psnr_std}\n')
            f.write(f'Test PSNR mean: {self.test_psnr_mean}\n')
            f.write(f'Test PSNR std: {self.test_psnr_std}\n')

    def test(self, test_loader):
        """Record test-set PSNR statistics and save one example output/target image pair."""
        _, self.test_psnr_mean, self.test_psnr_std = self.evaluate(test_loader)
        self.save_images(test_loader)

    def save_images(self, test_loader):
        """Save the first sample of the first test batch: model output and target."""
        self.model.eval()
        freq, image = next(iter(test_loader))
        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            output = self.model(freq.to(self.device)).squeeze(1)

        os.makedirs(f'{self.results_root}/images', exist_ok=True)
        mask_status = "learned_mask" if self.learn_mask else "unlearned_mask"

        # Save output image
        plt.imshow(output[0].detach().cpu().numpy(), cmap='gray')
        plt.savefig(f'{self.results_root}/images/{mask_status}_output_{self.drop_rate}.png')
        plt.close()

        # Save target image
        plt.imshow(image[0].detach().cpu().numpy(), cmap='gray')
        plt.savefig(f'{self.results_root}/images/{mask_status}_true_{self.drop_rate}.png')
        plt.close()

    def plot_losses(self, i=0):
        """Plot the recorded train/validation losses and save the figure under ``graphs/``."""
        os.makedirs(f'{self.results_root}/graphs', exist_ok=True)
        mask_status = "learned_mask" if self.learn_mask else "unlearned_mask"

        plt.figure()
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'{self.results_root}/graphs/{mask_status}_loss_graph_{self.drop_rate}_{i}.png')
        plt.close()

In [4]:
# Plain-object stand-in for argparse-style arguments: every setting is an instance attribute.
class Args:
    """Run configuration for this notebook, exposed as attributes."""

    def __init__(self):
        cuda_ok = torch.cuda.is_available()
        # Keyword order below fixes the attribute order on the instance.
        vars(self).update(
            seed=0,
            data_path='/datasets/fastmri_knee/',
            device='cuda' if cuda_ok else 'cpu',
            batch_size=16,
            num_workers=1,
            num_epochs=6,
            report_interval=10,
            drop_rate=0.1,
            learn_mask=False,
            results_root='results',
            lr=0.001,
            mask_lr=0.001,
            val_test_split=0.3,
        )


args = Args()

In [5]:
def create_datasets(args, resolution=320, data_size=50):
    '''Build the train, validation and test datasets, optionally truncated.

    Args:
        args: Object providing ``data_path`` (dataset root directory) and
            ``val_test_split`` (forwarded to SliceData as ``split`` for the
            validation/test datasets and used to size their truncation).
        resolution: Value passed to ``DataTransform``.
        data_size: Number of training examples to keep. The validation and
            test sets are cut to ``data_size * val_test_split // 4`` and
            ``data_size * (1 - val_test_split) // 4`` examples respectively.
            Pass ``None`` to keep every example of all three datasets.

    Returns:
        Tuple ``(train_data, dev_data, test_data)``. Validation and test data
        are both read from ``singlecoil_val`` and distinguished only by the
        ``validation`` flag given to SliceData.
    '''

    train_data = SliceData(
        root=f"{args.data_path}/singlecoil_train",
        transform=DataTransform(resolution),
        split=1
    )
    dev_data = SliceData(
        root=f"{args.data_path}/singlecoil_val",
        transform=DataTransform(resolution),
        split=args.val_test_split,
        validation=True
    )
    test_data = SliceData(
        root=f"{args.data_path}/singlecoil_val",
        transform=DataTransform(resolution),
        split=args.val_test_split,
        validation=False
    )

    if data_size is not None:
        # Truncate in place; the size formulas are unchanged from the
        # previously hard-coded data_size = 50 version.
        train_data.examples = train_data.examples[:data_size]
        dev_data.examples = dev_data.examples[:int((data_size * args.val_test_split) // 4)]
        test_data.examples = test_data.examples[:int(data_size * (1 - args.val_test_split) // 4)]

    return train_data, dev_data, test_data


def create_data_loaders(args):
    '''Wrap the datasets from create_datasets(args) in DataLoaders.

    All three loaders use args.batch_size, args.num_workers and pinned memory;
    only the training loader shuffles. Returns (train, validation, test) loaders.
    Transforms are applied by the datasets themselves, as configured in create_datasets.'''
    train_set, dev_set, test_set = create_datasets(args)

    # Settings common to every split.
    shared = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    return (
        DataLoader(dataset=train_set, shuffle=True, **shared),
        DataLoader(dataset=dev_set, **shared),
        DataLoader(dataset=test_set, **shared),
    )

def freq_to_image(freq_data):
    '''
    Convert frequency-domain data to a magnitude image.

    Applies transforms.ifft2_regular followed by transforms.complex_abs.
    Per the original description, the input has size (B,320,320,2) (B = batch size)
    and the result has size (B,320,320).
    '''
    spatial_domain = transforms.ifft2_regular(freq_data)
    return transforms.complex_abs(spatial_domain)


# Build the three loaders, then report how many samples each underlying dataset holds
# (train, validation, test — one count per line).
train_loader, validation_loader, test_loader = create_data_loaders(args)
print(
    *(len(split_loader.dataset) for split_loader in (train_loader, validation_loader, test_loader)),
    sep='\n',
)

50
3
8


In [6]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a projected shortcut.

    The shortcut is always a 1x1 convolution followed by batch norm, so the
    input is mapped to ``out_channels`` before being added to the main path.
    LeakyReLU(0.01) is applied after the first stage and after the addition.
    The block takes no dropout argument.

    NOTE(review): this class is defined in more than one cell of this
    notebook; whichever cell ran last is the definition in effect.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut path: channel projection so the residual sum is well-formed.
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        projection_norm = nn.BatchNorm2d(out_channels)
        self.skip = nn.Sequential(projection, projection_norm)

    def forward(self, x):
        """Return LeakyReLU(main_path(x) + shortcut(x))."""
        shortcut = self.skip(x)

        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)

        merged = main + shortcut
        return self.relu(merged)


class UNetModel(nn.Module):
    """Two-level U-Net built from ResidualBlock stages behind a SubsamplingLayer.

    Data flow: subsample -> down1 -> pool -> down2 -> pool -> bottleneck ->
    up1 (+ skip from down2) -> up2 (+ skip from down1) -> 1x1 conv.
    The network contains no attention layer.

    Args:
        drop_rate, device, learn_mask: Forwarded unchanged to SubsamplingLayer.
        num_channels: Width of the first stage; it doubles at each deeper level.
        pool_kernel_size: Down-sampling factor of each max-pool. The transposed
            convolutions use the same factor as kernel size and stride, so the
            decoder restores exactly the resolution the matching pool removed
            and the skip concatenations line up for any factor, not only 2.

    NOTE(review): this class is defined in more than one cell of this
    notebook; whichever cell ran last is the definition in effect.
    """

    def __init__(self, drop_rate, device, learn_mask, num_channels=32, pool_kernel_size=2):
        super().__init__()
        self.subsample = SubsamplingLayer(drop_rate, device, learn_mask)
        self.down1 = ResidualBlock(1, num_channels)
        self.pool1 = nn.MaxPool2d(kernel_size=pool_kernel_size)
        self.down2 = ResidualBlock(num_channels, num_channels * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=pool_kernel_size)
        self.bottleneck = ResidualBlock(num_channels * 2, num_channels * 4)
        # Up-sampling mirrors the pooling factor (previously fixed at 2).
        self.up1 = nn.ConvTranspose2d(
            num_channels * 4, num_channels * 2,
            kernel_size=pool_kernel_size, stride=pool_kernel_size,
        )
        # Input width doubles because of the concatenated skip connection.
        self.up_block1 = ResidualBlock(num_channels * 4, num_channels * 2)
        self.up2 = nn.ConvTranspose2d(
            num_channels * 2, num_channels,
            kernel_size=pool_kernel_size, stride=pool_kernel_size,
        )
        self.up_block2 = ResidualBlock(num_channels * 2, num_channels)
        self.final_conv = nn.Conv2d(num_channels, 1, kernel_size=1)

    def forward(self, x):
        """Run the network and return the 1-channel result with dim 1 squeezed out.

        The subsampling layer's output must be a 1-channel map (down1 is built
        with one input channel) whose spatial size is divisible by
        ``pool_kernel_size ** 2``; otherwise the skip concatenations fail on
        mismatched shapes.
        """
        x = self.subsample(x)

        # Downsampling
        x1 = self.down1(x)
        x = self.pool1(x1)
        x2 = self.down2(x)
        x = self.pool2(x2)

        # Bottleneck
        x = self.bottleneck(x)

        # Upsampling, concatenating the matching encoder features along channels
        x = self.up1(x)
        x = torch.cat([x2, x], dim=1)
        x = self.up_block1(x)
        x = self.up2(x)
        x = torch.cat([x1, x], dim=1)
        x = self.up_block2(x)

        return self.final_conv(x).squeeze(1)


In [8]:
# Apply the configured seed so weight initialisation and batch shuffling are
# repeatable whenever this cell is re-run.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

model = UNetModel(args.drop_rate, args.device, args.learn_mask).to(args.device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Pass learn_mask explicitly: Trainer uses it to label saved graphs, images and
# PSNR files, and its default (False) would mislabel a learned-mask run.
trainer = Trainer(
    model, optimizer, loss_fn, args.device, args.mask_lr,
    args.results_root, args.drop_rate, learn_mask=args.learn_mask,
)

print("Starting")
trainer.fit(train_loader, validation_loader, args.num_epochs, 0)
trainer.test(test_loader)
trainer.save_psnr_results()



Starting
Epopch: 0, Train Loss: 0.4953901022672653, Val Loss: 0.6425914168357849, Train PSNR: 21.094398975372314, Val PSNR: 17.198169708251953 
Epopch: 1, Train Loss: 0.23383933678269386, Val Loss: 0.4622328281402588, Train PSNR: 24.044400215148926, Val PSNR: 18.628910064697266 
Epopch: 2, Train Loss: 0.24694739654660225, Val Loss: 0.4198969006538391, Train PSNR: 23.055622577667236, Val PSNR: 19.046091079711914 
Epopch: 3, Train Loss: 0.2508646883070469, Val Loss: 0.26970407366752625, Train PSNR: 23.165512084960938, Val PSNR: 20.968643188476562 
Epopch: 4, Train Loss: 0.19921371713280678, Val Loss: 0.20227639377117157, Train PSNR: 24.62506866455078, Val PSNR: 22.21806526184082 
Epopch: 5, Train Loss: 0.22034508734941483, Val Loss: 0.18953582644462585, Train PSNR: 23.66673231124878, Val PSNR: 22.50060272216797 


In [7]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a projected shortcut.

    The shortcut is always a 1x1 convolution followed by batch norm, so the
    input is mapped to ``out_channels`` before being added to the main path.
    LeakyReLU(0.01) is applied after the first stage and after the addition.
    The block takes no dropout argument.

    NOTE(review): this class is defined in more than one cell of this
    notebook; whichever cell ran last is the definition in effect.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut path: channel projection so the residual sum is well-formed.
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        projection_norm = nn.BatchNorm2d(out_channels)
        self.skip = nn.Sequential(projection, projection_norm)

    def forward(self, x):
        """Return LeakyReLU(main_path(x) + shortcut(x))."""
        shortcut = self.skip(x)

        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)

        merged = main + shortcut
        return self.relu(merged)


class UNetModel(nn.Module):
    """Two-level U-Net built from ResidualBlock stages behind a SubsamplingLayer.

    Data flow: subsample -> down1 -> pool -> down2 -> pool -> bottleneck ->
    up1 (+ skip from down2) -> up2 (+ skip from down1) -> 1x1 conv.
    The network contains no attention layer.

    Args:
        drop_rate, device, learn_mask: Forwarded unchanged to SubsamplingLayer.
        num_channels: Width of the first stage; it doubles at each deeper level.
        pool_kernel_size: Down-sampling factor of each max-pool. The transposed
            convolutions use the same factor as kernel size and stride, so the
            decoder restores exactly the resolution the matching pool removed
            and the skip concatenations line up for any factor, not only 2.

    NOTE(review): this class is defined in more than one cell of this
    notebook; whichever cell ran last is the definition in effect.
    """

    def __init__(self, drop_rate, device, learn_mask, num_channels=32, pool_kernel_size=2):
        super().__init__()
        self.subsample = SubsamplingLayer(drop_rate, device, learn_mask)
        self.down1 = ResidualBlock(1, num_channels)
        self.pool1 = nn.MaxPool2d(kernel_size=pool_kernel_size)
        self.down2 = ResidualBlock(num_channels, num_channels * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=pool_kernel_size)
        self.bottleneck = ResidualBlock(num_channels * 2, num_channels * 4)
        # Up-sampling mirrors the pooling factor (previously fixed at 2).
        self.up1 = nn.ConvTranspose2d(
            num_channels * 4, num_channels * 2,
            kernel_size=pool_kernel_size, stride=pool_kernel_size,
        )
        # Input width doubles because of the concatenated skip connection.
        self.up_block1 = ResidualBlock(num_channels * 4, num_channels * 2)
        self.up2 = nn.ConvTranspose2d(
            num_channels * 2, num_channels,
            kernel_size=pool_kernel_size, stride=pool_kernel_size,
        )
        self.up_block2 = ResidualBlock(num_channels * 2, num_channels)
        self.final_conv = nn.Conv2d(num_channels, 1, kernel_size=1)

    def forward(self, x):
        """Run the network and return the 1-channel result with dim 1 squeezed out.

        The subsampling layer's output must be a 1-channel map (down1 is built
        with one input channel) whose spatial size is divisible by
        ``pool_kernel_size ** 2``; otherwise the skip concatenations fail on
        mismatched shapes.
        """
        x = self.subsample(x)

        # Downsampling
        x1 = self.down1(x)
        x = self.pool1(x1)
        x2 = self.down2(x)
        x = self.pool2(x2)

        # Bottleneck
        x = self.bottleneck(x)

        # Upsampling, concatenating the matching encoder features along channels
        x = self.up1(x)
        x = torch.cat([x2, x], dim=1)
        x = self.up_block1(x)
        x = self.up2(x)
        x = torch.cat([x1, x], dim=1)
        x = self.up_block2(x)

        return self.final_conv(x).squeeze(1)


In [None]:
class ConvBlock(nn.Module):
    """Two identical stages of: bias-free 3x3 conv -> InstanceNorm2d -> LeakyReLU(0.2) -> Dropout2d.

    The first stage maps ``in_chans`` to ``out_chans``; the second keeps
    ``out_chans``. ``drop_prob`` is the dropout probability of both stages.
    """

    def __init__(self, in_chans, out_chans, drop_prob=0):
        super().__init__()
        stages = []
        # Same recipe twice; only the input width of the convolution differs.
        for stage_in in (in_chans, out_chans):
            stages.append(nn.Conv2d(stage_in, out_chans, kernel_size=3, padding=1, bias=False))
            stages.append(nn.InstanceNorm2d(out_chans))
            stages.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            stages.append(nn.Dropout2d(drop_prob))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages in sequence."""
        return self.layers(x)

class UNetModel(nn.Module):
    """Two-level U-Net built from ResidualBlock stages behind a SubsamplingLayer.

    Data flow: subsample -> down1 -> pool -> down2 -> pool -> bottleneck ->
    up1 (+ skip from down2) -> up2 (+ skip from down1) -> 1x1 conv.
    The network contains no attention layer.

    Args:
        drop_rate, device, learn_mask: Forwarded unchanged to SubsamplingLayer.
        num_channels: Width of the first stage; it doubles at each deeper level.
        pool_kernel_size: Down-sampling factor of each max-pool. The transposed
            convolutions use the same factor as kernel size and stride, so the
            decoder restores exactly the resolution the matching pool removed
            and the skip concatenations line up for any factor, not only 2.

    NOTE(review): this class is defined in more than one cell of this
    notebook; whichever cell ran last is the definition in effect. This copy
    still builds its stages from ResidualBlock, not from the ConvBlock defined
    in the same cell — confirm which block was intended.
    """

    def __init__(self, drop_rate, device, learn_mask, num_channels=32, pool_kernel_size=2):
        super().__init__()
        self.subsample = SubsamplingLayer(drop_rate, device, learn_mask)
        self.down1 = ResidualBlock(1, num_channels)
        self.pool1 = nn.MaxPool2d(kernel_size=pool_kernel_size)
        self.down2 = ResidualBlock(num_channels, num_channels * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=pool_kernel_size)
        self.bottleneck = ResidualBlock(num_channels * 2, num_channels * 4)
        # Up-sampling mirrors the pooling factor (previously fixed at 2).
        self.up1 = nn.ConvTranspose2d(
            num_channels * 4, num_channels * 2,
            kernel_size=pool_kernel_size, stride=pool_kernel_size,
        )
        # Input width doubles because of the concatenated skip connection.
        self.up_block1 = ResidualBlock(num_channels * 4, num_channels * 2)
        self.up2 = nn.ConvTranspose2d(
            num_channels * 2, num_channels,
            kernel_size=pool_kernel_size, stride=pool_kernel_size,
        )
        self.up_block2 = ResidualBlock(num_channels * 2, num_channels)
        self.final_conv = nn.Conv2d(num_channels, 1, kernel_size=1)

    def forward(self, x):
        """Run the network and return the 1-channel result with dim 1 squeezed out.

        The subsampling layer's output must be a 1-channel map (down1 is built
        with one input channel) whose spatial size is divisible by
        ``pool_kernel_size ** 2``; otherwise the skip concatenations fail on
        mismatched shapes.
        """
        x = self.subsample(x)

        # Downsampling
        x1 = self.down1(x)
        x = self.pool1(x1)
        x2 = self.down2(x)
        x = self.pool2(x2)

        # Bottleneck
        x = self.bottleneck(x)

        # Upsampling, concatenating the matching encoder features along channels
        x = self.up1(x)
        x = torch.cat([x2, x], dim=1)
        x = self.up_block1(x)
        x = self.up2(x)
        x = torch.cat([x1, x], dim=1)
        x = self.up_block2(x)

        return self.final_conv(x).squeeze(1)
