In [4]:
import json
import os
import zipfile
from pathlib import Path

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.utils.data import DataLoader, Dataset, random_split
# Run on the default CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Extract the simulated dataset next to the notebook.
# Pure Python (zipfile/pathlib) instead of `!unzip` / `!find` does two things:
#   - the cell also works on Windows, where neither tool is available under cmd;
#   - re-running never blocks on unzip's interactive "replace?" prompt,
#     because extractall simply overwrites existing files.
# The archive is our own data; never extractall an untrusted archive without inspecting it.
with zipfile.ZipFile("pchdata.zip") as archive:
    archive.extractall("pchdata")

# Delete empty directories under pchdata/pchdata, deepest first, which mirrors
# `find pchdata/pchdata -type d -empty -delete`. Deepest-first means a parent
# that becomes empty once its children are removed is removed as well.
# P1aSet opens <folder>/GT.json for each entry, so an empty sample folder would
# otherwise make indexing fail.
extract_root = Path("pchdata/pchdata")
candidates = [extract_root, *extract_root.rglob("*")]
for path in sorted(candidates, key=lambda p: len(p.parts), reverse=True):
    if path.is_dir() and not any(path.iterdir()):
        path.rmdir()

In [9]:
class P1aSet(Dataset):
    """Dataset with one sample per sub-directory of ``root``.

    Each sample directory holds a ``GT.json`` file whose ``Data.pch_bins`` list
    is the model input and whose ``Data.true_bins`` list is the target.
    Entries of ``root`` that do not contain a ``GT.json`` file are skipped.
    Examples are empty folders or stray files.

    Args:
        root: directory whose sub-directories are the samples.
    """

    def __init__(self, root):
        self.root = root
        # Filter up front: without this, an entry lacking GT.json only fails
        # later, inside __getitem__, in the middle of an epoch.
        # Sorting keeps the index -> sample mapping deterministic across runs.
        self.folders = [
            name
            for name in sorted(os.listdir(self.root))
            if os.path.isfile(os.path.join(self.root, name, "GT.json"))
        ]

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, idx):
        """Return ``(pch, tdist)`` for sample ``idx`` as float32 tensors."""
        path = os.path.join(self.root, self.folders[idx], "GT.json")
        # JSON is UTF-8 by spec; the platform default encoding (e.g. cp1252 on
        # Windows) is not guaranteed to match.
        with open(path, "r", encoding="utf-8") as f:
            sim_data = json.load(f)
        pch = sim_data['Data']['pch_bins']
        tdist = sim_data['Data']['true_bins']
        return torch.tensor(pch, dtype=torch.float32), torch.tensor(tdist, dtype=torch.float32)

In [None]:
class ResidualBlock(nn.Module):
    """Conv1d residual unit: (conv3 -> BN -> GELU -> conv3 -> BN) + shortcut, then GELU.

    Both convolutions use ``kernel_size=3, padding=1``, so the sequence length
    is unchanged. The shortcut depends on the channel counts:
      - if they differ, the shortcut is a 1x1 convolution (``self.projection``)
        that maps the input to ``out_channels``;
      - otherwise the input is added as-is and ``self.projection`` is ``None``.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path. Submodules are created in this exact order so that random
        # initialisation and state_dict layout are unchanged.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Shortcut path: only needs a learned projection when the widths differ.
        channels_differ = in_channels != out_channels
        self.projection = nn.Conv1d(in_channels, out_channels, kernel_size=1) if channels_differ else None
        self.gelu = nn.GELU()

    def forward(self, x):
        # Main path; its final activation is deferred until after the residual add.
        hidden = self.gelu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Shortcut path.
        shortcut = x if self.projection is None else self.projection(x)
        return self.gelu(hidden + shortcut)

class BASCNN(nn.Module):
    """1-D residual CNN mapping a length-``input_length`` vector to ``output_dim`` probabilities.

    Pipeline:
      - ResidualBlock(1 -> 16) + Dropout(0.1);
      - ResidualBlock(16 -> 32) + Dropout(0.1);
      - flatten, then Dropout(0.25);
      - Linear(32 * input_length -> output_dim);
      - (log-)softmax over the output dimension.

    Args:
        input_length: length of each input vector. It must match the data,
            because the final linear layer expects ``32 * input_length`` features.
        output_dim: number of output bins.
    """

    def __init__(self, input_length=50, output_dim=100):
        super().__init__()
        self.resblock1 = ResidualBlock(1, 16)
        self.dropout1 = nn.Dropout(0.1)
        self.resblock2 = ResidualBlock(16, 32)
        self.dropout2 = nn.Dropout(0.1)

        self.flatten = nn.Flatten()
        # ResidualBlock preserves sequence length, so 32 channels x input_length positions.
        self.fc = nn.Linear(32 * input_length, output_dim)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x, return_log=False):
        """Run the network.

        Args:
            x: input of shape (batch_size, input_length); a channel axis is added here.
            return_log: if True, return ``log_softmax`` outputs; otherwise ``softmax``
                probabilities.

        Returns:
            Tensor of shape (batch_size, output_dim). Each row is a probability
            distribution, or its log when ``return_log`` is True.
        """
        x = x.unsqueeze(1)  # (batch_size, 1, input_length)
        x = self.dropout1(self.resblock1(x))  # (batch_size, 16, input_length)
        x = self.dropout2(self.resblock2(x))  # (batch_size, 32, input_length)

        x = self.flatten(x)  # (batch_size, 32 * input_length)
        x = self.dropout(x)
        x = self.fc(x)  # (batch_size, output_dim) logits
        if return_log:
            return F.log_softmax(x, dim=1)
        return F.softmax(x, dim=1)

In [None]:
# --- Hyperparameters ---
lr = 2e-5
BATCH_SIZE = 16
NUM_EPOCHS = 200
SEED = 42  # fixes weight init, DataLoader shuffling and the train/val/test split

torch.manual_seed(SEED)

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model = BASCNN().to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
# NOTE(review): MSE is applied to softmax probabilities. A distribution loss
# (e.g. KLDivLoss with model(..., return_log=True)) may fit better; confirm intent.
criterion = nn.MSELoss()

# Set the PCHDATA_ROOT environment variable to point elsewhere instead of editing this path.
DATA_ROOT = os.environ.get("PCHDATA_ROOT", r"C:\Users\omgui\Desktop\BASDL\data_gen_phase1a\pchdata")
dataset = P1aSet(DATA_ROOT)

train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1
assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, "split ratios must sum to 1"

# Val/test sizes are floored and train takes the remainder, so the three always sum to len(dataset).
val_size = int(val_ratio * len(dataset))
test_size = int(test_ratio * len(dataset))
train_size = len(dataset) - val_size - test_size

# A seeded generator makes the split reproducible, so the test set stays the
# same held-out samples across kernel restarts.
split_generator = torch.Generator().manual_seed(SEED)
train_set, val_set, test_set = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

# TODO: no learning-rate scheduler is configured yet.


In [None]:
model.to(device)
train_losses = []
val_losses = []

# `tqdm` is imported as a module, so the progress-bar class is `tqdm.tqdm`.
for epoch in tqdm.tqdm(range(NUM_EPOCHS), desc="epochs"):
    # --- training pass ---
    model.train()
    total_loss = 0.0
    for pch, tdist in train_loader:  # each batch is (inputs, targets); there is no index
        pch, tdist = pch.to(device), tdist.to(device)
        pred_dist = model(pch)

        loss = criterion(pred_dist, tdist)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # --- validation pass ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for pch, tdist in val_loader:
            pch, tdist = pch.to(device), tdist.to(device)
            pred_dist = model(pch)
            val_loss += criterion(pred_dist, tdist).item()

    # max(..., 1) avoids ZeroDivisionError when a split is empty (e.g. val_size == 0 on tiny datasets).
    mean_train_loss = total_loss / max(len(train_loader), 1)
    mean_val_loss = val_loss / max(len(val_loader), 1)
    train_losses.append(mean_train_loss)
    val_losses.append(mean_val_loss)

    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.tqdm.write(f"Epoch {epoch} | Loss: {mean_train_loss} | Val: {mean_val_loss}")

