In [2]:
import torch
from torchinfo import summary
from tqdm import tqdm
import TorchData
from sklearn.metrics import roc_auc_score

In [13]:
# Re-import the local TorchData package so on-disk edits are picked up without
# restarting the kernel; the reloaded module object is echoed as this cell's output.
import importlib

importlib.reload(TorchData)

<module 'TorchData' from 'F:\\projects\\ResearchProject\\TorchData\\__init__.py'>

In [14]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, and fall
# back to CPU when neither is present (previously "mps" was chosen unconditionally,
# which makes every later `.to(device)` fail on CPU-only / non-Apple machines).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [15]:
# Train / test splits generated by the project-local TorchData package.
# NOTE(review): 5096 may be a typo for 4096 — confirm the intended training-set size.
N_TRAIN, N_TEST = 5096, 1024
train_dataset = TorchData.TORCHDataset2Channel(num_data=N_TRAIN)
test_dataset = TorchData.TORCHDataset2Channel(num_data=N_TEST)

# Shuffled mini-batches for training; one sample at a time, in order, for testing.
train_dataloader = train_dataset.dataloader(batch_size=256, shuffle=True)
test_dataloader = test_dataset.dataloader(batch_size=1, shuffle=False)

In [16]:
class ConvolutionAutoencoder(torch.nn.Module):
    """Fully convolutional autoencoder for 2-channel images.

    Encoder: three unpadded 3x3 convolutions (2 -> 64 -> 16 -> 8 channels), each
    followed by BatchNorm2d and ReLU; together they shrink each spatial dimension by 6.
    Decoder: BatchNorm2d + ReLU, then three 3x3 transposed convolutions
    (8 -> 16 -> 64 -> 2 channels) that restore the input size, finishing with a
    Sigmoid so every output value lies in [0, 1].

    Layers are created in the same order and sit at the same ``Sequential`` indices
    as a hand-written layer list would give, so saved ``state_dict`` checkpoints
    (keys such as ``encoder.0.weight`` / ``decoder.2.weight``) load unchanged.
    """

    def __init__(self):
        super().__init__()

        # Encoder: (Conv2d -> BatchNorm2d -> ReLU) for each channel transition.
        conv_layers = []
        for c_in, c_out in ((2, 64), (64, 16), (16, 8)):
            conv_layers.extend((
                torch.nn.Conv2d(c_in, c_out, 3, stride=1),
                torch.nn.BatchNorm2d(c_out),
                torch.nn.ReLU(),
            ))
        self.encoder = torch.nn.Sequential(*conv_layers)

        # NOTE(review): this leading BatchNorm2d(8) + ReLU directly follows the encoder's
        # own BatchNorm2d(8) + ReLU, so it is largely redundant; it is kept because
        # removing it would shift decoder indices and break existing checkpoints.
        deconv_layers = [torch.nn.BatchNorm2d(8), torch.nn.ReLU()]
        for c_in, c_out in ((8, 16), (16, 64)):
            deconv_layers.extend((
                torch.nn.ConvTranspose2d(c_in, c_out, 3, stride=1),
                torch.nn.BatchNorm2d(c_out),
                torch.nn.ReLU(),
            ))
        # Final projection back to 2 channels, squashed into [0, 1].
        deconv_layers.extend((torch.nn.ConvTranspose2d(64, 2, 3, stride=1), torch.nn.Sigmoid()))
        self.decoder = torch.nn.Sequential(*deconv_layers)

    def forward(self, x):
        """Encode ``x`` and decode it back; the result has the same shape as ``x``."""
        return self.decoder(self.encoder(x))

In [17]:
# Instantiate the autoencoder on the selected device, with an MSE reconstruction
# loss and an Adam optimiser over all of its parameters.
LEARNING_RATE = 0.01
model = ConvolutionAutoencoder()
model = model.to(device)
loss_function = torch.nn.MSELoss()
optimiser = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [18]:
# Layer-by-layer summary for a single 2-channel 120x92 input.
# Pass `device` explicitly: when it is omitted, torchinfo picks CUDA-or-CPU on its own
# and moves the model there, which on an MPS machine would leave `model` on CPU while
# the training loop sends its inputs to `device`.
summary(model, input_size=(1, 2, 120, 92), device=device)

Layer (type:depth-idx)                   Output Shape              Param #
ConvolutionAutoencoder                   [1, 2, 120, 92]           --
├─Sequential: 1-1                        [1, 8, 114, 86]           --
│    └─Conv2d: 2-1                       [1, 64, 118, 90]          1,216
│    └─BatchNorm2d: 2-2                  [1, 64, 118, 90]          128
│    └─ReLU: 2-3                         [1, 64, 118, 90]          --
│    └─Conv2d: 2-4                       [1, 16, 116, 88]          9,232
│    └─BatchNorm2d: 2-5                  [1, 16, 116, 88]          32
│    └─ReLU: 2-6                         [1, 16, 116, 88]          --
│    └─Conv2d: 2-7                       [1, 8, 114, 86]           1,160
│    └─BatchNorm2d: 2-8                  [1, 8, 114, 86]           16
│    └─ReLU: 2-9                         [1, 8, 114, 86]           --
├─Sequential: 1-2                        [1, 2, 120, 92]           --
│    └─BatchNorm2d: 2-10                 [1, 8, 114, 86]           16
│    

In [19]:
# Train the autoencoder to map each input batch `x` to its target `y` under MSE.
# Re-assert device and train mode first, so this cell is safe to re-run after any
# cell that switched the model to eval mode or moved it to another device (otherwise
# BatchNorm statistics stay frozen or inputs/weights end up on different devices).
NUM_EPOCHS = 100
model.to(device)
model.train()

progress = tqdm(range(NUM_EPOCHS))
for epoch in progress:
    running_loss, n_batches = 0.0, 0
    for x, y in train_dataloader:
        optimiser.zero_grad()
        outputs = model(x.to(device))
        loss = loss_function(outputs, y.to(device))
        loss.backward()
        optimiser.step()
        running_loss += loss.item()
        n_batches += 1
    # Show the mean training loss of the epoch next to the progress bar.
    if n_batches:
        progress.set_postfix(loss=running_loss / n_batches)

  return torch.tensor(self.x[idx]), torch.tensor(self.y[idx])
  2%|▏         | 2/100 [00:13<11:22,  6.96s/it]


KeyboardInterrupt: 

In [None]:
### Using pretrained weights
# map_location lets a checkpoint that was saved on another device (e.g. CUDA) be
# deserialised on whatever `device` is available here; without it, loading a CUDA
# checkpoint on a CPU/MPS-only machine raises.
# NOTE(review): the file name says "hybrid_transformer" but the model is
# ConvolutionAutoencoder — confirm this is the intended checkpoint.
WEIGHTS_PATH = "ModelWeights/Conv2Channel/hybrid_transformer_weights_2000.pth"
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device, weights_only=True))

In [None]:
# Run the model over the whole test set in one CPU forward pass (no autograd).
model.eval()
sn_time_cpu = test_dataset.sn_time.to("cpu")
with torch.no_grad():
    all_pred = model.to("cpu")(sn_time_cpu)
# `.to()` moves the module in place; move it back so later cells (e.g. re-running
# training) still find the model on `device`.
model.to(device)

# Multiply the predictions elementwise by the model input, then zero every value
# below the threshold. Uses the CPU copy so both operands share a device.
PRED_THRESHOLD = 0.5  # NOTE(review): origin of 0.5 is not documented — confirm
all_pred_time = all_pred * sn_time_cpu
all_pred_time = torch.where(
    all_pred_time < PRED_THRESHOLD, torch.zeros_like(all_pred_time), all_pred_time
)

In [None]:
# Visual spot-check of the post-processed predictions against the test dataset.
# NOTE(review): the third positional argument (0) presumably selects which sample
# to plot — confirm against TorchData.visual.fast_compare_plot.
TorchData.visual.fast_compare_plot(
    test_dataset,
    all_pred_time,
    0,
)

In [None]:
# Score the post-processed predictions against the test targets.
# NOTE(review): MSE/PSNR/SSIM compare against `signal_time`, while ROC AUC uses
# `signal`; presumably `signal` holds the binary labels roc_auc_score needs — confirm.
reference_time = test_dataset.signal_time
mse, psnr, ssim = (
    metric_fn(reference_time, all_pred_time)
    for metric_fn in (
        TorchData.metric.calculate_mse_torch,
        TorchData.metric.calculate_psnr_torch,
        TorchData.metric.calculate_ssim_torch,
    )
)

labels_flat = test_dataset.signal.flatten().numpy()
scores_flat = all_pred_time.flatten().numpy()
roc_auc = roc_auc_score(labels_flat, scores_flat)
print(f"MSE: {mse:.4f}, PSNR: {psnr:.4f}, SSIM: {ssim:.4f}, ROC AUC: {roc_auc:.4f}")