<a href="https://colab.research.google.com/github/RiverBotham/Raman/blob/main/Raman%20Imaging%20Denoising%20-%20TransUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TODO:


*   Add utilities to github
*   Update this notebook to clone repo
*   Add updates to this notbook to run a train & test for de-noising using images from google drive but utilities from github
*   Add in k-means & testing framework
*   Repeat with second notebook for hyper-spectral super sesolution



In [1]:
# To save forst clone the repo
!git config --global user.name "RiverBotham"
!git config --global user.email "river.botham@gmail.com"
!git config --global user.password "MY_PASSWORD"

token = 'MY_TOKEN'
username = 'RiverBotham'
repo = 'Raman'

!git clone https://{token}@github.com/{username}/{repo}

Cloning into 'Raman'...
remote: Enumerating objects: 101, done.[K
remote: Counting objects: 100% (101/101), done.[K
remote: Compressing objects: 100% (98/98), done.[K
remote: Total 101 (delta 56), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (101/101), 7.49 MiB | 11.29 MiB/s, done.
Resolving deltas: 100% (56/56), done.


In [2]:
# Move into the cloned repo, then File -> Save copy in GitHub
%cd {repo}/Denoising

/content/Raman/Denoising


In [3]:
# Imports
import os
import sys
import random
import datetime
import time
import shutil
import argparse
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
import scipy.signal
import math
from skimage.metrics import structural_similarity as sk_ssim
from sklearn.model_selection import KFold

import torch
from torch import nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.utils.data.distributed
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp
from torchvision import transforms, utils
import pandas as pd
import matplotlib.pyplot as plt
from einops import rearrange, repeat

# import model, dataset, utilities

In [4]:
# model


class BasicConv(nn.Module):
    def __init__(self, channels_in, channels_out, batch_norm):
        super(BasicConv, self).__init__()
        basic_conv = [nn.Conv1d(channels_in, channels_out, kernel_size = 3, stride = 1, padding = 1, bias = True)]
        basic_conv.append(nn.PReLU())
        if batch_norm:
            basic_conv.append(nn.BatchNorm1d(channels_out))

        self.body = nn.Sequential(*basic_conv)

    def forward(self, x):
        return self.body(x)

class ResUNetConv(nn.Module):
    def __init__(self, num_convs, channels, batch_norm):
        super(ResUNetConv, self).__init__()
        unet_conv = []
        for _ in range(num_convs):
            unet_conv.append(nn.Conv1d(channels, channels, kernel_size = 3, stride = 1, padding = 1, bias = True))
            unet_conv.append(nn.PReLU())
            if batch_norm:
                unet_conv.append(nn.BatchNorm1d(channels))

        self.body = nn.Sequential(*unet_conv)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

class UNetLinear(nn.Module):
    def __init__(self, repeats, channels_in, channels_out):
        super().__init__()
        modules = []
        for i in range(repeats):
            modules.append(nn.Linear(channels_in, channels_out))
            modules.append(nn.PReLU())

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        x = self.body(x)
        return x

class ResUNet(nn.Module):
    def __init__(self, num_convs, batch_norm):
        super(ResUNet, self).__init__()
        res_conv1 = [BasicConv(1, 64, batch_norm)]
        res_conv1.append(ResUNetConv(num_convs, 64, batch_norm))
        self.conv1 = nn.Sequential(*res_conv1)
        self.pool1 = nn.MaxPool1d(2)

        res_conv2 = [BasicConv(64, 128, batch_norm)]
        res_conv2.append(ResUNetConv(num_convs, 128, batch_norm))
        self.conv2 = nn.Sequential(*res_conv2)
        self.pool2 = nn.MaxPool1d(2)

        res_conv3 = [BasicConv(128, 256, batch_norm)]
        res_conv3.append(ResUNetConv(num_convs, 256, batch_norm))
        res_conv3.append(BasicConv(256, 128, batch_norm))
        self.conv3 = nn.Sequential(*res_conv3)
        self.up3 = nn.Upsample(scale_factor = 2)

        res_conv4 = [BasicConv(256, 128, batch_norm)]
        res_conv4.append(ResUNetConv(num_convs, 128, batch_norm))
        res_conv4.append(BasicConv(128, 64, batch_norm))
        self.conv4 = nn.Sequential(*res_conv4)
        self.up4 = nn.Upsample(scale_factor = 2)

        res_conv5 = [BasicConv(128, 64, batch_norm)]
        res_conv5.append(ResUNetConv(num_convs,64, batch_norm))
        self.conv5 = nn.Sequential(*res_conv5)
        res_conv6 = [BasicConv(64, 1, batch_norm)]
        self.conv6 = nn.Sequential(*res_conv6)

        self.linear7 = UNetLinear(3, 500, 500)

    def forward(self, x):
        x = self.conv1(x)
        x1 = self.pool1(x)

        x2 = self.conv2(x1)
        x3 = self.pool1(x2)

        x3 = self.conv3(x3)
        x3 = self.up3(x3)

        x4 = torch.cat((x2, x3), dim = 1)
        x4 = self.conv4(x4)
        x5 = self.up4(x4)

        x6 = torch.cat((x, x5), dim = 1)
        x6 = self.conv5(x6)
        x7 = self.conv6(x6)

        out = self.linear7(x7)

        return out

In [5]:
# 1D U-Net architecture for signal denoising

class UNet1D(nn.Module):
    def __init__(self):
        super(UNet1D, self).__init__()

        # Encoder
        self.encoder1 = self.conv_block(1, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder
        self.upconv4 = self.upconv(1024, 512)
        self.decoder4 = self.conv_block(1024, 512)  # Concatenation doubles the channels

        self.upconv3 = self.upconv(512, 256)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = self.upconv(256, 128)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = self.upconv(128, 64)
        self.decoder1 = self.conv_block(128, 64)

        # Final output
        self.conv_final = nn.Conv1d(64, 1, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        return nn.ConvTranspose1d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        # Encoding path
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool1d(enc1, 2))
        enc3 = self.encoder3(F.max_pool1d(enc2, 2))
        enc4 = self.encoder4(F.max_pool1d(enc3, 2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool1d(enc4, 2))

        # Decoding path
        dec4 = self.upconv4(bottleneck)

        # Padding if necessary (for size mismatch)
        if dec4.size(2) != enc4.size(2):
            dec4 = F.pad(dec4, (0, enc4.size(2) - dec4.size(2)))

        dec4 = torch.cat((enc4, dec4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        if dec3.size(2) != enc3.size(2):
            dec3 = F.pad(dec3, (0, enc3.size(2) - dec3.size(2)))

        dec3 = torch.cat((enc3, dec3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        if dec2.size(2) != enc2.size(2):
            dec2 = F.pad(dec2, (0, enc2.size(2) - dec2.size(2)))

        dec2 = torch.cat((enc2, dec2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        if dec1.size(2) != enc1.size(2):
            dec1 = F.pad(dec1, (0, enc1.size(2) - dec1.size(2)))

        dec1 = torch.cat((enc1, dec1), dim=1)
        dec1 = self.decoder1(dec1)

        return self.conv_final(dec1)

In [6]:

class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv1d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm1d(middle_channels)
        self.conv2 = nn.Conv1d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out


class UNetPlusPlus1D(nn.Module):
    def __init__(self, input_channels=1, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool1d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv1d(nb_filter[0], 1, kernel_size=1)
            self.final2 = nn.Conv1d(nb_filter[0], 1, kernel_size=1)
            self.final3 = nn.Conv1d(nb_filter[0], 1, kernel_size=1)
            self.final4 = nn.Conv1d(nb_filter[0], 1, kernel_size=1)
        else:
            self.final = nn.Conv1d(nb_filter[0], 1, kernel_size=1)


    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))

                # Pad self.up(x3_1) to match x2_0 spatial dimensions
        upsampled_x3_0 = self.up(x3_0)
        padded_upsampled_x3_0 = F.pad(upsampled_x3_0, (0, x2_0.size(2) - upsampled_x3_0.size(2)))

        x2_1 = self.conv2_1(torch.cat([x2_0, padded_upsampled_x3_0], 1)) # Concatenate with padded tensor
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))

        padded_x2_1 = F.pad(x2_1, (0, x2_0.size(2) - x2_1.size(2)))
        upsampled_x3_1 = self.up(x3_1)
        padded_upsampled_x3_1 = F.pad(upsampled_x3_1, (0, x2_0.size(2) - upsampled_x3_1.size(2)))

        x2_2 = self.conv2_2(torch.cat([x2_0, padded_x2_1, padded_upsampled_x3_1], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]

        else:
            output = self.final(x0_4)
            return output

In [7]:
class ConvBlockA(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
                                  nn.Conv1d(ch_in, ch_out,
                                            kernel_size=3, stride=1,
                                            padding=1, bias=True),
                                  nn.BatchNorm1d(ch_out),
                                  nn.ReLU(inplace=True),
                                  nn.Conv1d(ch_out, ch_out,
                                            kernel_size=3, stride=1,
                                            padding=1, bias=True),
                                  nn.BatchNorm1d(ch_out),
                                  nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class UpConvBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up = nn.Sequential(
                                nn.Upsample(scale_factor=2),
                                nn.Conv1d(ch_in, ch_out,
                                         kernel_size=3,stride=1,
                                         padding=1, bias=True),
                                nn.BatchNorm1d(ch_out),
                                nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = x = self.up(x)
        return x

class AttentionBlock(nn.Module):
    def __init__(self, f_g, f_l, f_int):
        super().__init__()

        self.w_g = nn.Sequential(
                                nn.Conv1d(f_g, f_int,
                                         kernel_size=1, stride=1,
                                         padding=0, bias=True),
                                nn.BatchNorm1d(f_int)
        )

        self.w_x = nn.Sequential(
                                nn.Conv1d(f_l, f_int,
                                         kernel_size=1, stride=1,
                                         padding=0, bias=True),
                                nn.BatchNorm1d(f_int)
        )

        self.psi = nn.Sequential(
                                nn.Conv1d(f_int, 1,
                                         kernel_size=1, stride=1,
                                         padding=0,  bias=True),
                                nn.BatchNorm1d(1),
                                nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.w_g(g)
        x1 = self.w_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)

        return psi*x

class AttentionUNet(nn.Module):
    def __init__(self, n_classes=1, in_channel=1, out_channel=1):
        super().__init__()

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv1 = ConvBlockA(ch_in=in_channel, ch_out=64)
        self.conv2 = ConvBlockA(ch_in=64, ch_out=128)
        self.conv3 = ConvBlockA(ch_in=128, ch_out=256)
        self.conv4 = ConvBlockA(ch_in=256, ch_out=512)
        self.conv5 = ConvBlockA(ch_in=512, ch_out=1024)

        self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
        self.att5 = AttentionBlock(f_g=512, f_l=512, f_int=256)
        self.upconv5 = ConvBlockA(ch_in=1024, ch_out=512)

        self.up4 = UpConvBlock(ch_in=512, ch_out=256)
        self.att4 = AttentionBlock(f_g=256, f_l=256, f_int=128)
        self.upconv4 = ConvBlockA(ch_in=512, ch_out=256)

        self.up3 = UpConvBlock(ch_in=256, ch_out=128)
        self.att3 = AttentionBlock(f_g=128, f_l=128, f_int=64)
        self.upconv3 = ConvBlockA(ch_in=256, ch_out=128)

        self.up2 = UpConvBlock(ch_in=128, ch_out=64)
        self.att2 = AttentionBlock(f_g=64, f_l=64, f_int=32)
        self.upconv2 = ConvBlockA(ch_in=128, ch_out=64)

        self.conv_1x1 = nn.Conv1d(64, out_channel,
                                  kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder
        x1 = self.conv1(x)

        x2 = self.maxpool(x1)
        x2 = self.conv2(x2)

        x3 = self.maxpool(x2)
        x3 = self.conv3(x3)

        x4 = self.maxpool(x3)
        x4 = self.conv4(x4)

        x5 = self.maxpool(x4)
        x5 = self.conv5(x5)

        # decoder + concat
        d5 = self.up5(x5)
        x4 = self.att5(g=d5, x=x4)
        d5 = torch.concat((x4, d5), dim=1)
        d5 = self.upconv5(d5)

        d4 = self.up4(d5)
        d4 = F.pad(d4, (0, x3.size(2) - d4.size(2)))
        x3 = self.att4(g=d4, x=x3)
        d4 = torch.concat((x3, d4), dim=1)
        d4 = self.upconv4(d4)

        d3 = self.up3(d4)
        x2 = self.att3(g=d3, x=x2)
        d3 = torch.concat((x2, d3), dim=1)
        d3 = self.upconv3(d3)

        d2 = self.up2(d3)
        x1 = self.att2(g=d2, x=x1)
        d2 = torch.concat((x1, d2), dim=1)
        d2 = self.upconv2(d2)

        d1 = self.conv_1x1(d2)

        return d1

In [64]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, head_num):
        super().__init__()

        self.head_num = head_num
        self.dk = (embedding_dim // head_num) ** (1 / 2)

        self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
        self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x, mask=None):
        qkv = self.qkv_layer(x)

        query, key, value = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.head_num))
        energy = torch.einsum("... i d, ... j d -> ... i j", query, key) * self.dk

        if mask is not None:
            energy = energy.masked_fill(mask, -float('inf'))

        attention = torch.softmax(energy, dim=-1)

        x = torch.einsum("... i j, ... j d -> ... i d", attention, value)

        x = rearrange(x, "b h t d -> b t (h d)")
        x = self.out_attention(x)

        return x


class MLP(nn.Module):
    def __init__(self, embedding_dim, mlp_dim):
        super().__init__()

        self.mlp_layers = nn.Sequential(
            nn.Linear(embedding_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mlp_dim, embedding_dim),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.mlp_layers(x)


class TransformerEncoderBlock(nn.Module):
    def __init__(self, embedding_dim, head_num, mlp_dim):
        super().__init__()

        self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
        self.mlp = MLP(embedding_dim, mlp_dim)

        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        _x = self.multi_head_attention(x)
        _x = self.dropout(_x)
        x = x + _x
        x = self.layer_norm1(x)

        _x = self.mlp(x)
        x = x + _x
        x = self.layer_norm2(x)

        return x


class Transformer1D(nn.Module):
    def __init__(self, input_dim, embedding_dim, head_num, mlp_dim, block_num):
        super().__init__()

        self.embedding = nn.Linear(input_dim, embedding_dim)  # Embedding for 1D input
        self.transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)
        ])

    def forward(self, x):
        # x shape: (batch_size, seq_length, input_dim)
        x = self.embedding(x)  # Shape: (batch_size, seq_length, embedding_dim)
        for block in self.transformer_blocks:
            x = block(x)
        return x


class EncoderBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, base_width=64):
        super().__init__()

        self.downsample = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm1d(out_channels)
        )

        width = int(out_channels * (base_width / 64))

        self.conv1 = nn.Conv1d(in_channels, width, kernel_size=1, stride=1, bias=False)
        self.norm1 = nn.BatchNorm1d(width)

        self.conv2 = nn.Conv1d(width, width, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm2 = nn.BatchNorm1d(width)

        self.conv3 = nn.Conv1d(width, out_channels, kernel_size=1, stride=1, bias=False)
        self.norm3 = nn.BatchNorm1d(out_channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x_down = self.downsample(x)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.norm3(x)
        x = x + x_down
        x = self.relu(x)

        return x


class DecoderBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='linear', align_corners=True)
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, x_concat=None):
        x = self.upsample(x)

        if x_concat is not None:
            x = torch.cat([x_concat, x], dim=1)

        x = self.layer(x)
        return x


class Encoder(nn.Module):
    def __init__(self, signal_length, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)
        self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)
        self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2)

        self.vit_signal_length = signal_length // patch_dim
        self.vit = Transformer1D(self.vit_signal_length, out_channels * 8, head_num, mlp_dim, block_num)

        self.conv2 = nn.Conv1d(out_channels * 8, 512, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm1d(512)

    def forward(self, x):
        x = self.conv1(x)  # Shape: (batch_size, out_channels, reduced_length)
        x = self.norm1(x)
        x1 = self.relu(x)

        x2 = self.encoder1(x1)
        x3 = self.encoder2(x2)
        x = self.encoder3(x3)

        # Reshape for transformer input
        x = rearrange(x, 'b c l -> b l c')  # Change to (batch_size, length, channels)
        x = self.vit(x)  # Pass through the transformer
        x = rearrange(x, 'b l c -> b c l')  # Change back to (batch_size, channels, length)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)

        return x, x1, x2, x3


class Decoder(nn.Module):
    def __init__(self, out_channels, class_num):
        super().__init__()

        self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 4)
        self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels * 2)
        self.decoder3 = DecoderBottleneck(out_channels * 2, int(out_channels * 1 / 2))
        self.decoder4 = DecoderBottleneck(int(out_channels * 1 / 2), int(out_channels * 1 / 8))

        self.conv1 = nn.Conv1d(int(out_channels * 1 / 8), class_num, kernel_size=1)

    def forward(self, x, x1, x2, x3):
        x = self.decoder1(x, x3)
        x = self.decoder2(x, x2)
        x = self.decoder3(x, x1)
        x = self.decoder4(x)
        x = self.conv1(x)

        return x


class TransUNet(nn.Module):
    def __init__(self, signal_length, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num):
        super().__init__()

        self.encoder = Encoder(signal_length, in_channels, out_channels,
                               head_num, mlp_dim, block_num, patch_dim)

        self.decoder = Decoder(out_channels, class_num)

    def forward(self, x):
        x, x1, x2, x3 = self.encoder(x)
        x = self.decoder(x, x1, x2, x3)

        return x

In [31]:
# # Trans U net

# import torch
# import torch.nn as nn
# import numpy as np
# from einops import rearrange, repeat

# # Multi-Head Attention for 1D
# class MultiHeadAttention(nn.Module):
#     def __init__(self, embedding_dim, head_num):
#         super().__init__()

#         self.head_num = head_num
#         self.dk = (embedding_dim // head_num) ** (1 / 2)

#         self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
#         self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)

#     def forward(self, x, mask=None):
#         qkv = self.qkv_layer(x)

#         query, key, value = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.head_num))
#         energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk

#         if mask is not None:
#             energy = energy.masked_fill(mask, -np.inf)

#         attention = torch.softmax(energy, dim=-1)

#         x = torch.einsum("... i j , ... j d -> ... i d", attention, value)

#         x = rearrange(x, "b h t d -> b t (h d)")
#         x = self.out_attention(x)

#         return x


# class MLP(nn.Module):
#     def __init__(self, embedding_dim, mlp_dim):
#         super().__init__()

#         self.mlp_layers = nn.Sequential(
#             nn.Linear(embedding_dim, mlp_dim),
#             nn.GELU(),
#             nn.Dropout(0.1),
#             nn.Linear(mlp_dim, embedding_dim),
#             nn.Dropout(0.1)
#         )

#     def forward(self, x):
#         x = self.mlp_layers(x)

#         return x


# class TransformerEncoderBlock(nn.Module):
#     def __init__(self, embedding_dim, head_num, mlp_dim):
#         super().__init__()

#         self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
#         self.mlp = MLP(embedding_dim, mlp_dim)

#         self.layer_norm1 = nn.LayerNorm(embedding_dim)
#         self.layer_norm2 = nn.LayerNorm(embedding_dim)

#         self.dropout = nn.Dropout(0.1)

#     def forward(self, x):
#         _x = self.multi_head_attention(x)
#         _x = self.dropout(_x)
#         x = x + _x
#         x = self.layer_norm1(x)

#         _x = self.mlp(x)
#         x = x + _x
#         x = self.layer_norm2(x)

#         return x


# class TransformerEncoder(nn.Module):
#     def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12):
#         super().__init__()

#         self.layer_blocks = nn.ModuleList(
#             [TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)])

#     def forward(self, x):
#         for layer_block in self.layer_blocks:
#             x = layer_block(x)

#         return x


# class ViT(nn.Module):
#     def __init__(self, seq_len, in_channels, embedding_dim, head_num, mlp_dim,
#                  block_num, patch_dim, classification=False, num_classes=1):
#         super().__init__()

#         self.patch_dim = patch_dim
#         self.classification = classification
#         self.num_tokens = (seq_len // patch_dim) ** 2
#         self.token_dim = in_channels * (patch_dim ** 2)

#         self.projection = nn.Linear(self.token_dim, embedding_dim)
#         self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))

#         self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

#         self.dropout = nn.Dropout(0.1)

#         self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num)

#         if self.classification:
#             self.mlp_head = nn.Linear(embedding_dim, num_classes)

#     def forward(self, x):
#         seq_patches = rearrange(x,
#                                 'b c (patch s) -> b (s) (patch c)',
#                                 patch=self.patch_dim)

#         batch_size, tokens, _ = seq_patches.shape

#         project = self.projection(seq_patches)
#         token = repeat(self.cls_token, 'b ... -> (b batch_size) ...',
#                        batch_size=batch_size)

#         patches = torch.cat([token, project], dim=1)
#         patches += self.embedding[:tokens + 1, :]

#         x = self.dropout(patches)
#         x = self.transformer(x)
#         x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]

#         return x



# class EncoderBottleneck(nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1, base_width=64):
#         super().__init__()

#         self.downsample = nn.Sequential(
#             nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
#             nn.BatchNorm1d(out_channels)
#         )

#         width = int(out_channels * (base_width / 64))

#         self.conv1 = nn.Conv1d(in_channels, width, kernel_size=1, stride=1, bias=False)
#         self.norm1 = nn.BatchNorm1d(width)

#         self.conv2 = nn.Conv1d(width, width, kernel_size=3, stride=2, groups=1, padding=1, dilation=1, bias=False)
#         self.norm2 = nn.BatchNorm1d(width)

#         self.conv3 = nn.Conv1d(width, out_channels, kernel_size=1, stride=1, bias=False)
#         self.norm3 = nn.BatchNorm1d(out_channels)

#         self.relu = nn.ReLU(inplace=True)

#     def forward(self, x):
#         x_down = self.downsample(x)

#         x = self.conv1(x)
#         x = self.norm1(x)
#         x = self.relu(x)

#         x = self.conv2(x)
#         x = self.norm2(x)
#         x = self.relu(x)

#         x = self.conv3(x)
#         x = self.norm3(x)
#         x = x + x_down
#         x = self.relu(x)

#         return x


# class DecoderBottleneck(nn.Module):
#     def __init__(self, in_channels, out_channels, scale_factor=2):
#         super().__init__()

#         self.upsample = nn.Upsample(scale_factor=scale_factor, mode='linear', align_corners=True)
#         self.layer = nn.Sequential(
#             nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm1d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm1d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x, x_concat=None):
#         x = self.upsample(x)

#         if x_concat is not None:
#             print(f"x_concat shape: {x_concat.shape}")
#             print(f"x shape: {x.shape}")
#             x = torch.cat([x_concat, x], dim=1)

#         x = self.layer(x)
#         return x


# class Encoder(nn.Module):
#     def __init__(self, seq_len, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):
#         super().__init__()

#         self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)
#         self.norm1 = nn.BatchNorm1d(out_channels)
#         self.relu = nn.ReLU(inplace=True)

#         self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)
#         self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)
#         self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2)

#         self.vit_seq_len = seq_len // patch_dim
#         self.vit = ViT(self.vit_seq_len, out_channels * 8, out_channels * 8,
#                        head_num, mlp_dim, block_num, patch_dim=1, classification=False)

#         self.conv2 = nn.Conv1d(out_channels * 8, 512, kernel_size=3, stride=1, padding=1)
#         self.norm2 = nn.BatchNorm1d(512)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.norm1(x)
#         x1 = self.relu(x)

#         x2 = self.encoder1(x1)
#         x3 = self.encoder2(x2)
#         x = self.encoder3(x3)

#         x = self.vit(x)
#         x = rearrange(x, "b c l -> b l c")

#         x = self.conv2(x)
#         x = self.norm2(x)
#         x = self.relu(x)

#         return x, x1, x2, x3


# class Decoder(nn.Module):
#     def __init__(self, out_channels, class_num):
#         super().__init__()

#         self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 2)
#         self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels)
#         self.decoder3 = DecoderBottleneck(out_channels * 2, int(out_channels * 1 / 2))
#         self.decoder4 = DecoderBottleneck(int(out_channels * 1 / 2), int(out_channels * 1 / 8))

#         self.conv1 = nn.Conv1d(int(out_channels * 1 / 8), class_num, kernel_size=1)

#     def forward(self, x, x1, x2, x3):
#         print(f"Initial x1 shape: {x1.shape}, x2 shape: {x2.shape}")
#         print(f"Initial x shape: {x.shape}, x3 shape: {x3.shape}")
#         x = self.decoder1(x, x3)
#         print(f"After decoder 1 x shape: {x.shape}, x2 shape: {x2.shape}")
#         x = self.decoder2(x, x2)
#         print(f"After decoder 2 x shape: {x.shape}, x1 shape: {x1.shape}")
#         x = self.decoder3(x, x1)
#         print(f"After decoder 3 x shape: {x.shape}")
#         x = self.decoder4(x)
#         print(f"Final decoder output shape before conv1: {x.shape}")
#         x = self.conv1(x)

#         return x


# class TransUNet(nn.Module):
#     def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num):
#         super().__init__()

#         self.encoder = Encoder(img_dim, in_channels, out_channels,
#                                head_num, mlp_dim, block_num, patch_dim)

#         self.decoder = Decoder(out_channels, class_num)

#     def forward(self, x):
#         x, x1, x2, x3 = self.encoder(x)
#         x = self.decoder(x, x1, x2, x3)

#         return x

In [48]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from einops import rearrange, repeat

# # Multi-Head Attention for 1D
# class MultiHeadAttention(nn.Module):
#     def __init__(self, embedding_dim, head_num):
#         super().__init__()
#         self.head_num = head_num
#         self.dk = (embedding_dim // head_num) ** (1 / 2)
#         self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
#         self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)

#     def forward(self, x, mask=None):
#         qkv = self.qkv_layer(x)
#         query, key, value = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.head_num))
#         energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk

#         if mask is not None:
#             energy = energy.masked_fill(mask, float('-inf'))

#         attention = torch.softmax(energy, dim=-1)
#         x = torch.einsum("... i j , ... j d -> ... i d", attention, value)
#         x = rearrange(x, "b h t d -> b t (h d)")
#         x = self.out_attention(x)

#         return x

# # MLP for Feedforward
# class MLP(nn.Module):
#     def __init__(self, embedding_dim, mlp_dim):
#         super().__init__()
#         self.mlp_layers = nn.Sequential(
#             nn.Linear(embedding_dim, mlp_dim),
#             nn.GELU(),
#             nn.Dropout(0.1),
#             nn.Linear(mlp_dim, embedding_dim),
#             nn.Dropout(0.1)
#         )

#     def forward(self, x):
#         return self.mlp_layers(x)

# # Transformer Encoder Block
# class TransformerEncoderBlock(nn.Module):
#     def __init__(self, embedding_dim, head_num, mlp_dim):
#         super().__init__()
#         self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
#         self.mlp = MLP(embedding_dim, mlp_dim)
#         self.layer_norm1 = nn.LayerNorm(embedding_dim)
#         self.layer_norm2 = nn.LayerNorm(embedding_dim)
#         self.dropout = nn.Dropout(0.1)

#     def forward(self, x):
#         _x = self.multi_head_attention(x)
#         _x = self.dropout(_x)
#         x = x + _x
#         x = self.layer_norm1(x)
#         _x = self.mlp(x)
#         x = x + _x
#         x = self.layer_norm2(x)
#         return x

# # Transformer Encoder
# class TransformerEncoder(nn.Module):
#     def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12):
#         super().__init__()
#         self.layer_blocks = nn.ModuleList(
#             [TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)]
#         )

#     def forward(self, x):
#         for layer_block in self.layer_blocks:
#             x = layer_block(x)
#         return x

# # Vision Transformer
# class ViT(nn.Module):
#     def __init__(self, seq_len, in_channels, embedding_dim, head_num, mlp_dim, block_num, patch_dim):
#         super().__init__()
#         self.patch_dim = patch_dim
#         self.num_tokens = (seq_len // patch_dim)
#         self.token_dim = in_channels * (patch_dim)

#         self.projection = nn.Linear(self.token_dim, embedding_dim)
#         self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))
#         self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
#         self.dropout = nn.Dropout(0.1)
#         self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num)

#     def forward(self, x):
#         seq_patches = rearrange(x, 'b c (patch s) -> b (s) (patch c)', patch=self.patch_dim)
#         batch_size, tokens, _ = seq_patches.shape
#         project = self.projection(seq_patches)
#         token = repeat(self.cls_token, 'b ... -> (b batch_size) ...', batch_size=batch_size)

#         patches = torch.cat([token, project], dim=1)
#         patches += self.embedding[:tokens + 1, :]
#         x = self.dropout(patches)
#         x = self.transformer(x)
#         return x[:, 1:, :]  # Return the patch tokens only

# # Encoder Bottleneck
# class EncoderBottleneck(nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1):
#         super().__init__()
#         self.downsample = nn.Sequential(
#             nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
#             nn.BatchNorm1d(out_channels)
#         )

#         self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
#         self.norm1 = nn.BatchNorm1d(out_channels)

#         self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
#         self.norm2 = nn.BatchNorm1d(out_channels)

#         self.relu = nn.ReLU(inplace=True)

#     def forward(self, x):
#         x_down = self.downsample(x)

#         x = self.conv1(x)
#         x = self.norm1(x)
#         x = self.relu(x)

#         x = self.conv2(x)
#         x = self.norm2(x)
#         x = x + x_down
#         x = self.relu(x)

#         return x

# # Encoder
# class Encoder(nn.Module):
#     def __init__(self, spectrum_len, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):
#         super().__init__()
#         self.bottleneck1 = EncoderBottleneck(in_channels, out_channels)        # Output shape will be [batch_size, out_channels, spectrum_len / 2]
#         self.bottleneck2 = EncoderBottleneck(out_channels, out_channels * 2)  # Output shape will be [batch_size, out_channels * 2, spectrum_len / 4]
#         self.bottleneck3 = EncoderBottleneck(out_channels * 2, out_channels * 4)  # Output shape will be [batch_size, out_channels * 4, spectrum_len / 8]

#         self.transformer = ViT(spectrum_len // 8, out_channels * 4, out_channels * 4, head_num, mlp_dim, block_num, patch_dim)  # Adjusting input dimensions for ViT

#     def forward(self, x):
#         print(f"Input shape: {x.shape}")

#         x1 = self.bottleneck1(x)
#         print(f"After bottleneck1 shape: {x1.shape}")

#         x2 = self.bottleneck2(x1)
#         print(f"After bottleneck2 shape: {x2.shape}")

#         x3 = self.bottleneck3(x2)
#         print(f"After bottleneck3 shape: {x3.shape}")

#         # Preparing for ViT
#         x3_reshaped = rearrange(x3, 'b c t -> b t c')  # Reshape for ViT input
#         print(f"Reshaped x3 for ViT: {x3_reshaped.shape}")

#         x3_transformed = self.transformer(x3_reshaped)  # Pass through the ViT
#         print(f"Output shape from ViT: {x3_transformed.shape}")

#         return x3_transformed, x1, x2, x3

# # Decoder Bottleneck
# class DecoderBottleneck(nn.Module):
#     def __init__(self, in_channels, out_channels, scale_factor=2):
#         super().__init__()
#         self.upsample = nn.Upsample(scale_factor=scale_factor, mode='linear', align_corners=True)
#         self.layer = nn.Sequential(
#             nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm1d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm1d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x, x_concat=None):
#         x = self.upsample(x)
#         if x_concat is not None:
#             # Ensure sizes match before concatenation
#             if x_concat.size(2) != x.size(2):
#                 diff = x_concat.size(2) - x.size(2)
#                 x_concat = F.pad(x_concat, (0, -diff, 0, 0))  # Pad x_concat
#             x = torch.cat((x, x_concat), dim=1)
#         return self.layer(x)

# # Decoder
# class Decoder(nn.Module):
#     def __init__(self, out_channels, num_classes):
#         super().__init__()
#         self.decoder3 = DecoderBottleneck(out_channels * 4, out_channels * 2)
#         self.decoder2 = DecoderBottleneck(out_channels * 2 * 2, out_channels)  # *2 for concatenation
#         self.decoder1 = DecoderBottleneck(out_channels * 2, out_channels)  # *2 for concatenation
#         self.final_conv = nn.Conv1d(out_channels, num_classes, kernel_size=1)

#     def forward(self, x, x1, x2, x3):
#         print(f"Input to decoder: {x.shape}")
#         x = self.decoder3(x, x3)
#         print(f"After decoder3 shape: {x.shape}")
#         x = self.decoder2(x, x2)
#         print(f"After decoder2 shape: {x.shape}")
#         x = self.decoder1(x, x1)
#         print(f"After decoder1 shape: {x.shape}")
#         return self.final_conv(x)

# # Full TransUNet Model
# class TransUNet(nn.Module):
#     def __init__(self, spectrum_len, in_channels, out_channels, head_num, mlp_dim, block_num, num_classes, patch_dim):
#         super().__init__()
#         self.encoder = Encoder(spectrum_len, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim)
#         self.decoder = Decoder(out_channels, num_classes)

#     def forward(self, x):
#         x, x1, x2, x3 = self.encoder(x)
#         return self.decoder(x, x1, x2, x3)

In [59]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class MultiHeadSelfAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(MultiHeadSelfAttention, self).__init__()
#         self.num_heads = num_heads
#         self.head_dim = embed_dim // num_heads

#         assert (
#             self.head_dim * num_heads == embed_dim
#         ), "Embedding dimension must be divisible by number of heads"

#         self.q_linear = nn.Linear(embed_dim, embed_dim)
#         self.k_linear = nn.Linear(embed_dim, embed_dim)
#         self.v_linear = nn.Linear(embed_dim, embed_dim)
#         self.fc_out = nn.Linear(embed_dim, embed_dim)

#     def forward(self, x):
#         N, seq_length, embed_dim = x.shape
#         q = self.q_linear(x).view(N, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, seq_length, head_dim)
#         k = self.k_linear(x).view(N, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, seq_length, head_dim)
#         v = self.v_linear(x).view(N, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, seq_length, head_dim)

#         # Calculate attention scores
#         energy = torch.einsum("nqhd,nkhd->nqkhd", [q, k])  # (N, num_heads, seq_length, seq_length)
#         attention = F.softmax(energy / (self.head_dim ** 0.5), dim=3)

#         out = torch.einsum("nqkhd,nkhd->nqhd", [attention, v]).reshape(N, seq_length, embed_dim)
#         return self.fc_out(out)

# class EncoderBlock(nn.Module):
#     def __init__(self, in_channels, embed_dim):
#         super(EncoderBlock, self).__init__()
#         self.conv = nn.Conv1d(in_channels, embed_dim, kernel_size=3, padding=1)
#         self.attn = MultiHeadSelfAttention(embed_dim, num_heads=4)
#         self.pool = nn.MaxPool1d(2)

#     def forward(self, x):
#         x = F.relu(self.conv(x))  # Convolution layer
#         x = x.permute(0, 2, 1)  # Change to (B, L, C) for attention
#         x = self.attn(x)  # Apply Multi-Head Self-Attention
#         x = x.permute(0, 2, 1)  # Change back to (B, C, L)
#         skip = x  # Skip connection
#         x = self.pool(x)
#         return x, skip

# class DecoderBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(DecoderBlock, self).__init__()
#         self.upconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=2, stride=2)
#         self.conv = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)

#     def forward(self, x, skip):
#         x = self.upconv(x)

#         # Adjust the size of skip if necessary
#         if x.size(1) != skip.size(1):  # Check number of channels
#             # Create a Conv1d layer with the same dtype as skip
#             conv_adjust = nn.Conv1d(skip.size(1), x.size(1), kernel_size=1).to(skip.dtype)
#             skip = conv_adjust(skip)  # Adjust channels with 1x1 convolution

#         if x.size(2) != skip.size(2):  # Check length
#             skip = skip[:, :, :x.size(2)]  # Trim the skip connection if necessary

#         x = x + skip  # Skip connection
#         x = F.relu(self.conv(x))
#         return x

# class TransUNetBasic(nn.Module):
#     def __init__(self, in_channels=1, out_channels=1):
#         super(TransUNetBasic, self).__init__()
#         self.encoder1 = EncoderBlock(in_channels, 64)
#         self.encoder2 = EncoderBlock(64, 128)
#         self.encoder3 = EncoderBlock(128, 256)

#         self.decoder3 = DecoderBlock(256, 128)
#         self.decoder2 = DecoderBlock(128, 64)
#         self.decoder1 = DecoderBlock(64, out_channels)

#     def forward(self, x):
#         enc1, skip1 = self.encoder1(x)
#         enc2, skip2 = self.encoder2(enc1)
#         enc3, skip3 = self.encoder3(enc2)

#         dec3 = self.decoder3(enc3, skip3)
#         dec2 = self.decoder2(dec3, skip2)
#         dec1 = self.decoder1(dec2, skip1)

#         return dec1

In [9]:
# data set

class RamanDataset(Dataset):
    def __init__(self, inputs, outputs, batch_size=64,spectrum_len=500, spectrum_shift=0.,
                 spectrum_window=False, horizontal_flip=False, mixup=False):
        self.inputs = inputs
        self.outputs = outputs
        self.batch_size = batch_size
        self.spectrum_len = spectrum_len
        self.spectrum_shift = spectrum_shift
        self.spectrum_window = spectrum_window
        self.horizontal_flip = horizontal_flip
        self.mixup = mixup
        self.on_epoch_end()

    def pad_spectrum(self, input_spectrum, spectrum_length):
        if len(input_spectrum) == spectrum_length:
            padded_spectrum = input_spectrum
        elif len(input_spectrum) > spectrum_length:
            padded_spectrum = input_spectrum[0:spectrum_length]
        else:
            padded_spectrum = np.pad(input_spectrum, ((0,spectrum_length - len(input_spectrum)),(0,0)), 'reflect')

        return padded_spectrum

    def window_spectrum(self, input_spectrum, start_idx, window_length):
        if len(input_spectrum) <= window_length:
            output_spectrum = input_spectrum
        else:
            end_idx = start_idx + window_length
            output_spectrum = input_spectrum[start_idx:end_idx]

        return output_spectrum

    def flip_axis(self, x, axis):
        if np.random.random() < 0.5:
            x = np.asarray(x).swapaxes(axis, 0)
            x = x[::-1, ...]
            x = x.swapaxes(0, axis)
        return x

    def shift_spectrum(self, x, shift_range):
        x = np.expand_dims(x,axis=-1)
        shifted_spectrum = x
        spectrum_shift_range = int(np.round(shift_range*len(x)))
        if spectrum_shift_range > 0:
            shifted_spectrum = np.pad(x[spectrum_shift_range:,:], ((0,abs(spectrum_shift_range)), (0,0)), 'reflect')
        elif spectrum_shift_range < 0:
            shifted_spectrum = np.pad(x[:spectrum_shift_range,:], ((abs(spectrum_shift_range), 0), (0,0)), 'reflect')
        return shifted_spectrum

    def mixup_spectrum(self, input_spectrum1, input_spectrum2, output_spectrum1, output_spectrum2, alpha):
        lam = np.random.beta(alpha, alpha)
        input_spectrum = (lam * input_spectrum1) + ((1 - lam) * input_spectrum2)
        output_spectrum = (lam * output_spectrum1) + ((1 - lam) * output_spectrum2)
        return input_spectrum, output_spectrum

    def __getitem__(self, index):
        input_spectrum = self.inputs[index]
        output_spectrum = self.outputs[index]

        mixup_on = False
        if self.mixup:
            if np.random.random() < 0.5:
                spectrum_idx = int(np.round(np.random.random() * (len(self.inputs)-1)))
                input_spectrum2 = self.inputs[spectrum_idx]
                output_spectrum2 = self.outputs[spectrum_idx]
                mixup_on = True

        if self.spectrum_window:
            start_idx = int(np.floor(np.random.random() * (len(input_spectrum)-self.spectrum_len)))
            input_spectrum = self.window_spectrum(input_spectrum, start_idx, self.spectrum_len)
            output_spectrum = self.window_spectrum(output_spectrum, start_idx, self.spectrum_len)
            if mixup_on:
                input_spectrum2 = self.window_spectrum(input_spectrum2, start_idx, self.spectrum_len)
                output_spectrum2 = self.window_spectrum(output_spectrum2, start_idx, self.spectrum_len)

        input_spectrum = self.pad_spectrum(input_spectrum, self.spectrum_len)
        output_spectrum = self.pad_spectrum(output_spectrum, self.spectrum_len)
        if mixup_on:
            input_spectrum2 = self.pad_spectrum(input_spectrum2, self.spectrum_len)
            output_spectrum2 = self.pad_spectrum(output_spectrum2, self.spectrum_len)

        if self.spectrum_shift != 0.0:
            shift_range = np.random.uniform(-self.spectrum_shift, self.spectrum_shift)
            input_spectrum = self.shift_spectrum(input_spectrum, shift_range)
            output_spectrum = self.shift_spectrum(output_spectrum, shift_range)
            if mixup_on:
                input_spectrum2 = self.shift_spectrum(input_spectrum2, shift_range)
                output_spectrum2 = self.shift_spectrum(output_spectrum2, shift_range)
        else:
            input_spectrum = np.expand_dims(input_spectrum, axis=-1)
            output_spectrum = np.expand_dims(output_spectrum, axis=-1)
            if mixup_on:
                input_spectrum2 = np.expand_dims(input_spectrum2, axis=-1)
                output_spectrum2 = np.expand_dims(output_spectrum2, axis=-1)

        if self.horizontal_flip:
            if np.random.random() < 0.5:
                input_spectrum = self.flip_axis(input_spectrum, 0)
                output_spectrum = self.flip_axis(output_spectrum, 0)
                if mixup_on:
                    input_spectrum2 = self.flip_axis(input_spectrum2, 0)
                    output_spectrum2 = self.flip_axis(output_spectrum2, 0)

        if mixup_on:
            input_spectrum, output_spectrum = self.mixup_spectrum(input_spectrum, input_spectrum2, output_spectrum, output_spectrum2, 0.2)

        input_spectrum = input_spectrum/np.amax(input_spectrum)
        output_spectrum = output_spectrum/np.amax(output_spectrum)

        input_spectrum = np.moveaxis(input_spectrum, -1, 0)
        output_spectrum = np.moveaxis(output_spectrum, -1, 0)

        sample = {'input_spectrum': input_spectrum, 'output_spectrum': output_spectrum}

        return sample

    def on_epoch_end(self):
        pass

    def __len__(self):
        return len(self.inputs)

In [10]:
# utilities

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

In [11]:
def train(dataloader, net, optimizer, scheduler, criterion, criterion_MSE, epoch, args):

    # For measuring time and losses
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss MSE', ':.4e')
    progress = ProgressMeter(len(dataloader), [batch_time, losses], prefix="Epoch: [{}]".format(epoch))

    # Initialize GradScaler for automatic mixed precision
    scaler = amp.GradScaler()

    end = time.time()

    # Iterate through the batches
    for i, data in enumerate(dataloader):
        # Efficient data transfer to GPU
        inputs = data['input_spectrum'].float().cuda(args.gpu, non_blocking=True)
        target = data['output_spectrum'].float().cuda(args.gpu, non_blocking=True)

        # Forward pass with mixed precision (autocast)
        with amp.autocast():
            output = net(inputs)
            loss = criterion(output, target)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Backward pass with scaled gradients for stability in AMP
        scaler.scale(loss).backward()

        # Step the optimizer using the scaled loss
        scaler.step(optimizer)
        scaler.update()

        # Update learning rate scheduler, if applicable
        if args.scheduler in ["cyclic-lr", "one-cycle-lr"]:
            scheduler.step()

        # MSE loss calculation outside of AMP (usually in FP32 for stability)
        with torch.no_grad():
            loss_MSE = criterion_MSE(output, target)
            losses.update(loss_MSE.item(), inputs.size(0))

        # Measure batch processing time
        batch_time.update(time.time() - end)
        end = time.time()

        # Log progress every 400 iterations
        if i % 400 == 0:
            progress.display(i)

    return losses.avg




In [12]:
def validate(dataloader, net, criterion_MSE, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(dataloader), [batch_time, losses], prefix='Validation: ')

    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(dataloader):
            inputs = data['input_spectrum']
            inputs = inputs.float()
            inputs = inputs.cuda(args.gpu)
            target = data['output_spectrum']
            target = target.float()
            target = target.cuda(args.gpu)

            output = net(inputs)

            loss_MSE = criterion_MSE(output, target)
            losses.update(loss_MSE.item(), inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % 400 == 0:
                progress.display(i)

    return losses.avg

In [62]:
def train_noKmeans(args):

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()

    gpu = args.gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
    if args.network == "UNetPlusPlus1D":
        net = UNetPlusPlus1D().float()
    elif args.network == "UNet1D":
        net = UNet1D().float()
    elif args.network == "AttentionUNet":
        net = AttentionUNet().float()
    elif args.network == "TransUNet":
        net = TransUNet(signal_length=500, in_channels=1, out_channels=64, head_num=8, mlp_dim=128, block_num=6, patch_dim=4, class_num=10)
        # net = TransUNet(spectrum_len=500, in_channels=1, out_channels=64, head_num=8, mlp_dim=256, block_num=4, num_classes=1, patch_dim=2).float()
    # elif args.network == "TransUNetBasic":
    #     net = TransUNetBasic(in_channels=1, out_channels=1).float()


    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
        else:
            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
    else:
        net.cuda(args.gpu)
        net = torch.nn.parallel.DistributedDataParallel(net)

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    Input_Data = scipy.io.loadmat("Dataset/Train_Inputs.mat")
    Output_Data = scipy.io.loadmat("Dataset/Train_Outputs.mat")

    Input = Input_Data['Train_Inputs']
    Output = Output_Data['Train_Outputs']

    spectra_num = len(Input)

    train_split = round(0.9 * spectra_num)
    val_split = round(0.1 * spectra_num)

    input_train = Input[:train_split]
    input_val = Input[train_split:train_split+val_split]

    output_train = Output[:train_split]
    output_val = Output[train_split:train_split+val_split]

    # ----------------------------------------------------------------------------------------
    # Create datasets (with augmentation) and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Train = RamanDataset(input_train, output_train, batch_size = args.batch_size, spectrum_len = args.spectrum_len,
                                   spectrum_shift=0.1, spectrum_window = False, horizontal_flip = False, mixup = True)

    Raman_Dataset_Val = RamanDataset(input_val, output_val, batch_size = args.batch_size, spectrum_len = args.spectrum_len)

# From here down per fold
    train_loader = DataLoader(Raman_Dataset_Train, batch_size = args.batch_size, shuffle = False, num_workers = 0, pin_memory = True)
    val_loader = DataLoader(Raman_Dataset_Val, batch_size = args.batch_size, shuffle = False, num_workers = 0, pin_memory = True)

    # ----------------------------------------------------------------------------------------
    # Define criterion(s), optimizer(s), and scheduler(s)
    # ----------------------------------------------------------------------------------------
    criterion = nn.L1Loss().cuda(args.gpu)
    criterion_MSE = nn.MSELoss().cuda(args.gpu)
    if args.optimizer == "sgd":
        optimizer = optim.SGD(net.parameters(), lr = args.lr)
    elif args.optimizer == "adamW":
        optimizer = optim.AdamW(net.parameters(), lr = args.lr)
    else: # Adam
        optimizer = optim.Adam(net.parameters(), lr = args.lr)

    if args.scheduler == "decay-lr":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)
    elif args.scheduler == "multiplicative-lr":
        lmbda = lambda epoch: 0.985
        scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
    elif args.scheduler == "cyclic-lr":
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr = args.base_lr, max_lr = args.lr, mode = 'triangular2', cycle_momentum = False)
    elif args.scheduler == "one-cycle-lr":
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs, cycle_momentum = False)
    else: # constant-lr
        scheduler = None

    print('Started Training')
    print('Training Details:')
    print('Network:         {}'.format(args.network))
    print('Epochs:          {}'.format(args.epochs))
    print('Batch Size:      {}'.format(args.batch_size))
    print('Optimizer:       {}'.format(args.optimizer))
    print('Scheduler:       {}'.format(args.scheduler))
    print('Learning Rate:   {}'.format(args.lr))
    print('Spectrum Length: {}'.format(args.spectrum_len))

    DATE = datetime.datetime.now().strftime("%Y_%m_%d")

    formatted_lr = '{:_.6f}'.format(float(args.lr)).rstrip('0').rstrip('.')
    losses_dir = "losses/{}_{}_{}_{}_{}.csv".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network)
    models_dir = "{}_{}_{}_{}_{}.pt".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network)

    df = pd.DataFrame(columns=['epoch', 'train_loss', 'val_loss'])

    # Early stopping
    patience = args.patience if hasattr(args, 'patience') else 10  # Default patience of 10 epochs
    best_val_loss = float('inf')
    epochs_no_improve = 0


    for epoch in range(args.epochs):
        train_loss = train(train_loader, net, optimizer, scheduler, criterion, criterion_MSE, epoch, args)
        val_loss = validate(val_loader, net, criterion_MSE, args)
        if args.scheduler == "decay-lr" or args.scheduler == "multiplicative-lr":
            scheduler.step()

        print("Epoch: ", epoch)
        print("Train Loss: ", train_loss)
        print("Val Loss: ", val_loss)
        new_row = pd.DataFrame({'epoch': [epoch], 'train_loss': [train_loss], 'val_loss': [val_loss]})

        df = pd.concat([df, new_row], ignore_index=True)

        # Early Stopping Logic
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered. No improvement in validation loss for {patience} epochs. Finished at epoch {epoch}")
                break

        torch.cuda.empty_cache()

    torch.save(net.state_dict(), models_dir)
    df.to_csv(losses_dir, index=False)
    print('Finished Training')

In [15]:
def train_kmeans(args, k_folds = 2):

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()

    gpu = args.gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    Input_Data = scipy.io.loadmat("Dataset/Train_Inputs.mat")
    Output_Data = scipy.io.loadmat("Dataset/Train_Outputs.mat")

    Input = Input_Data['Train_Inputs']
    Output = Output_Data['Train_Outputs']


    # ----------------------------------------------------------------------------------------
    # Create datasets (with augmentation) and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Train = RamanDataset(Input, Output, batch_size = args.batch_size, spectrum_len = args.spectrum_len,
                                   spectrum_shift=0.1, spectrum_window = False, horizontal_flip = False, mixup = True)

    # Raman_Dataset_Val = RamanDataset(input_val, output_val, batch_size = args.batch_size, spectrum_len = args.spectrum_len)

# From here down per fold
    kf = KFold(n_splits=k_folds, shuffle=True)
    for fold, (train_idx, test_idx) in enumerate(kf.split(Raman_Dataset_Train)):
      print(f"Fold {fold + 1}")
      print("-------")

      train_loader = DataLoader(Raman_Dataset_Train, batch_size = args.batch_size, shuffle = False, num_workers = 0, pin_memory = True, sampler=torch.utils.data.SubsetRandomSampler(train_idx))
      val_loader = DataLoader(Raman_Dataset_Train, batch_size = args.batch_size, shuffle = False, num_workers = 0, pin_memory = True, sampler=torch.utils.data.SubsetRandomSampler(test_idx))

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
      if args.network == "UNetPlusPlus1D":
        net = UNetPlusPlus1D().float()
      elif args.network == "UNet1D":
        net = UNet1D().float()
      elif args.network == "AttentionUNet":
        net = AttentionUNet().float()
      elif args.network == "TransUNet":
        net = TransUNet(args.spectrum_len, 1, 64, 8, 256, 4, 16, 1)

      if args.distributed:
          if args.gpu is not None:
              torch.cuda.set_device(args.gpu)
              args.batch_size = int(args.batch_size / ngpus_per_node)
              args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

              net.cuda(args.gpu)
              net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
          else:
              net.cuda(args.gpu)
              net = torch.nn.parallel.DistributedDataParallel(net)
      elif args.gpu is not None:
          torch.cuda.set_device(args.gpu)
          net.cuda(args.gpu)
      else:
          net.cuda(args.gpu)
          net = torch.nn.parallel.DistributedDataParallel(net)



      # ----------------------------------------------------------------------------------------
      # Define criterion(s), optimizer(s), and scheduler(s)
      # ----------------------------------------------------------------------------------------
      criterion = nn.L1Loss().cuda(args.gpu)
      criterion_MSE = nn.MSELoss().cuda(args.gpu)
      if args.optimizer == "sgd":
          optimizer = optim.SGD(net.parameters(), lr = args.lr)
      elif args.optimizer == "adamW":
          optimizer = optim.AdamW(net.parameters(), lr = args.lr)
      else: # Adam
          optimizer = optim.Adam(net.parameters(), lr = args.lr)

      if args.scheduler == "decay-lr":
          scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)
      elif args.scheduler == "multiplicative-lr":
          lmbda = lambda epoch: 0.985
          scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
      elif args.scheduler == "cyclic-lr":
          scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr = args.base_lr, max_lr = args.lr, mode = 'triangular2', cycle_momentum = False)
      elif args.scheduler == "one-cycle-lr":
          scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = args.lr, steps_per_epoch=len(train_loader), epochs=args.epochs, cycle_momentum = False)
      else: # constant-lr
          scheduler = None

      print('Started Training')
      print('Training Details:')
      print('Network:         {}'.format(args.network))
      print('Epochs:          {}'.format(args.epochs))
      print('Batch Size:      {}'.format(args.batch_size))
      print('Optimizer:       {}'.format(args.optimizer))
      print('Scheduler:       {}'.format(args.scheduler))
      print('Learning Rate:   {}'.format(args.lr))
      print('Spectrum Length: {}'.format(args.spectrum_len))

      DATE = datetime.datetime.now().strftime("%Y_%m_%d")

      formatted_lr = '{:_.6f}'.format(float(args.lr)).rstrip('0').rstrip('.')

      losses_dir = "losses/{}_{}_{}_{}_{}_fold_{}.csv".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network, fold + 1)
      models_dir = "{}_{}_{}_{}_{}_fold_{}.pt".format(DATE, args.optimizer, args.scheduler, formatted_lr, args.network, fold + 1)

      df = pd.DataFrame(columns=['epoch', 'train_loss', 'val_loss'])

      patience = args.patience if hasattr(args, 'patience') else 10  # Default patience of 10 epochs
      best_val_loss = float('inf')
      epochs_no_improve = 0

      for epoch in range(args.epochs):
          train_loss = train(train_loader, net, optimizer, scheduler, criterion, criterion_MSE, epoch, args)
          val_loss = validate(val_loader, net, criterion_MSE, args)
          if args.scheduler == "decay-lr" or args.scheduler == "multiplicative-lr":
              scheduler.step()

          print("Epoch: ", epoch)
          print("Train Loss: ", train_loss)
          print("Val Loss: ", val_loss)
          new_row = pd.DataFrame({'epoch': [epoch], 'train_loss': [train_loss], 'val_loss': [val_loss]})

          df = pd.concat([df, new_row], ignore_index=True)

                  # Early Stopping Logic
          if val_loss < best_val_loss:
              best_val_loss = val_loss
              epochs_no_improve = 0
          else:
              epochs_no_improve += 1
              if epochs_no_improve >= patience:
                  print(f"Early stopping triggered. No improvement in validation loss for {patience} epochs. Finished at epoch {epoch}")
                  break

          torch.cuda.empty_cache()

      torch.save(net.state_dict(), models_dir)
      df.to_csv(losses_dir, index=False)
      print('Finished Training')

In [14]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [15]:
%ls

dataset.py  model.py  ResUNet.pt  utilities.py


In [16]:
%cd ../..

/content


In [17]:
%cd drive/My\ Drive/Colab\ Notebooks/DeepeR-master/Raman Spectral Denoising

/content/drive/My Drive/Colab Notebooks/DeepeR-master/Raman Spectral Denoising


In [65]:
# Default args from original code
#Namespace(workers=0, epochs=2, start_epoch=0, batch_size=256, network='ResUNet', optimizer='adam', lr=0.0005, base_lr=5e-06, scheduler='one-cycle-lr', batch_norm=True, spectrum_len=500, seed=None, gpu=0, world_size=-1, rank=-1, dist_url='tcp://224.66.41.62:23456', dist_backend='nccl', multiprocessing_distributed=False)

class Arguments:
    pass

args = Arguments()
args.workers = 0
args.epochs = 500
args.start_epoch = 0
args.batch_size = 256
args.network = "TransUNet"
args.optimizer = "adamW"
args.lr = 1e-4
args.base_lr = 5e-6
args.scheduler = "one-cycle-lr"
args.batch_norm = True
args.spectrum_len = 500
args.seed = None
args.gpu = 0
args.world_size = -1
args.rank = -1
args.dist_url = "tcp://224.66.41.62:23456"
args.dist_backend = "nccl"
args.multiprocessing_distributed = False
args.patience = 10


args.epochs=2
train_noKmeans(args)
# train_kmeans(args)

Use GPU: 0 for training
Started Training
Training Details:
Network:         TransUNet
Epochs:          2
Batch Size:      256
Optimizer:       adamW
Scheduler:       one-cycle-lr
Learning Rate:   0.0001
Spectrum Length: 500


  scaler = amp.GradScaler()
  with amp.autocast():


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x512 and 125x512)

In [66]:
!pip install torch torchvision



In [None]:
from google.colab import runtime
runtime.unassign()

In [18]:
def calc_psnr(output, target):
    psnr = 0.
    mse = nn.MSELoss()(output, target)
    psnr = 10 * math.log10(torch.max(output)/mse)
    return psnr

def calc_ssim(output, target):
    ssim = 0.
    output = output.cpu().detach().numpy()
    target = target.cpu().detach().numpy()

    if output.ndim == 4:
        for i in range(output.shape[0]):
            output_i = np.squeeze(output[i,:,:,:])
            output_i = np.moveaxis(output_i, 0, -1)
            target_i = np.squeeze(target[i,:,:,:])
            target_i = np.moveaxis(target_i, 0, -1)
            batch_size = output.shape[0]
            ssim += sk_ssim(output_i, target_i, data_range = output_i.max() - target_i.max(), multichannel=True)
    else:
        output_i = np.squeeze(output)
        output_i = np.moveaxis(output_i, 0, -1)
        target_i = np.squeeze(target)
        target_i = np.moveaxis(target_i, 0, -1)
        batch_size = 1
        ssim += sk_ssim(output_i, target_i, data_range = output_i.max() - target_i.max(), multichannel=True)

    ssim = ssim / batch_size
    return ssim

In [19]:
# Testing
def evaluate(dataloader, net, args):
    losses = AverageMeter('Loss', ':.4e')
    psnr = AverageMeter('PSNR', ':.4f')
    ssim = AverageMeter('SSIM', ':.4f')
    SG_loss = AverageMeter('Savitzky-Golay Loss', ':.4e')

    net.eval()

    MSE_SG = []

    with torch.no_grad():
        for i, data in enumerate(dataloader):
            x = data['input_spectrum']
            inputs = x.float()
            inputs = inputs.cuda(args.gpu)
            y = data['output_spectrum']
            target = y.float()
            target = target.cuda(args.gpu)

            x = np.squeeze(x.numpy())
            y = np.squeeze(y.numpy())

            output = net(inputs)
            loss = nn.MSELoss()(output, target)

            x_out = output.cpu().detach().numpy()
            x_out = np.squeeze(x_out)

            SGF_1_9 = scipy.signal.savgol_filter(x,9,1)
            MSE_SGF_1_9 = np.mean(np.mean(np.square(np.absolute(y - (SGF_1_9 - np.reshape(np.amin(SGF_1_9, axis = 1), (len(SGF_1_9),1)))))))
            MSE_SG.append(MSE_SGF_1_9)

            psnr_batch = calc_psnr(output, target)
            psnr.update(psnr_batch, inputs.size(0))
            ssim_batch = calc_ssim(output, target)
            ssim.update(ssim_batch, inputs.size(0))

            losses.update(loss.item(), inputs.size(0))

        print("Neural Network MSE: {}".format(losses.avg))
        print("Neural Network PSNR: {}".format(psnr.avg))
        print("Neural Network SSIM: {}".format(ssim.avg))
        print("Savitzky-Golay MSE: {}".format(np.mean(np.asarray(MSE_SG))))
        print("Neural Network performed {0:.2f}x better than Savitzky-Golay".format(np.mean(np.asarray(MSE_SG))/losses.avg))

    return losses.avg, psnr.avg, ssim.avg, MSE_SG

In [20]:
def main_test(args):
    gpu = args.gpu
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()


    if args.gpu is not None:
        print("Use GPU: {} for testing".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # ----------------------------------------------------------------------------------------
    # Create model(s) and send to device(s)
    # ----------------------------------------------------------------------------------------
    if args.network == "UNetPlusPlus1D":
      net = UNetPlusPlus1D().float()
    elif args.network == "UNet1D":
      net = UNet1D().float()
    elif args.network == "AttentionUNet":
      net = AttentionUNet().float()
    net.load_state_dict(torch.load(args.model))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
        else:
            net.cuda(args.gpu)
            net = torch.nn.parallel.DistributedDataParallel(net)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
    else:
        net.cuda(args.gpu)
        net = torch.nn.parallel.DistributedDataParallel(net)

    # ----------------------------------------------------------------------------------------
    # Define dataset path and data splits
    # ----------------------------------------------------------------------------------------
    Input_Data = scipy.io.loadmat("Dataset/Test_Inputs.mat")
    Output_Data = scipy.io.loadmat("Dataset/Test_Outputs.mat")

    Input = Input_Data['Test_Inputs']
    Output = Output_Data['Test_Outputs']

    # ----------------------------------------------------------------------------------------
    # Create datasets (with augmentation) and dataloaders
    # ----------------------------------------------------------------------------------------
    Raman_Dataset_Test = RamanDataset(Input, Output, batch_size = args.batch_size, spectrum_len = args.spectrum_len)

    test_loader = DataLoader(Raman_Dataset_Test, batch_size = args.batch_size, shuffle = False, num_workers = 0, pin_memory = True)

    # ----------------------------------------------------------------------------------------
    # Evaluate
    # ----------------------------------------------------------------------------------------
    MSE_NN, PSNR_NN, SSIM_NN, MSE_SG = evaluate(test_loader, net, args)

In [23]:
class Arguments:
    pass

args = Arguments()
args.workers = 0
args.batch_size = 256
args.spectrum_len = 500
args.seed = None
args.gpu = 0
args.world_size = -1
args.rank = -1
args.dist_url = "tcp://224.66.41.62:23456"
args.dist_backend = "nccl"
args.multiprocessing_distributed = False
args.batch_norm = True
args.network = "UNet1D"
args.model = "2024_10_09_adamW_one-cycle-lr_0.001_UNet1D.pt"


main_test(args)

Use GPU: 0 for testing


  net.load_state_dict(torch.load(args.model))


Neural Network MSE: 0.0023067386205582094
Neural Network PSNR: 26.385451591812636
Neural Network SSIM: 0.2742282462580637
Savitzky-Golay MSE: 0.027660622850368643
Neural Network performed 11.99x better than Savitzky-Golay


In [None]:
file_name = 'losses/training_logs.csv'

df = pd.read_csv(file_name)

# Plotting train_loss and val_loss
plt.plot(df['epoch'], df['train_loss'], label='Train Loss', marker='o')
plt.plot(df['epoch'], df['val_loss'], label='Validation Loss', marker='o')

# Adding labels and title
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train and Validation Loss vs Epoch')

# Adding a legend to distinguish the two lines
plt.legend()

# Show the plot
plt.show()
