In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import os

In [None]:
class CNN2D_AIRS(nn.Module):
    """Two-stage convolutional encoder (used by ``Model`` for the AIRS frames).

    Every stage is a 3x3 convolution with padding 1, a ReLU and a 2x2 max-pool
    with stride 2, so height and width are each halved (floored) twice.

    Args:
        in_channels: number of channels of the input images.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        # The GroupNorm modules are registered (so they appear in state_dict)
        # but forward() currently bypasses them.
        self.gn1 = nn.GroupNorm(4, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.gn2 = nn.GroupNorm(4, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the pooled activations after both conv -> ReLU -> pool stages."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        return x

class CNN2D_FGS(nn.Module):
    """Two-stage convolutional encoder (used by ``Model`` for the FGS frames).

    Applies conv(3x3, padding 1) -> ReLU -> max-pool(2x2, stride 2) twice.

    Args:
        in_channels: number of channels of the input images.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # gn1 / gn2 are created but not applied in forward(); they still
        # contribute entries to the module's state_dict.
        self.gn1 = nn.GroupNorm(num_groups=4, num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(num_groups=4, num_channels=out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the feature map after the two conv/ReLU/pool stages."""
        first_stage = self.pool(F.relu(self.conv1(x)))
        second_stage = self.pool(F.relu(self.conv2(first_stage)))
        return second_stage

class Model(nn.Module):
    """Regress 283 outputs from two image sequences and two 1-D light curves.

    Each image sequence is encoded frame by frame with a small CNN
    (``CNN2D_AIRS`` on 32x356 frames, ``CNN2D_FGS`` on 32x32 frames), the
    per-frame features are fed to a single-layer LSTM and its final hidden
    state (128 values) is kept. Each light curve is embedded to 64 values by a
    linear layer with ReLU. The four vectors are concatenated (384 values) and
    passed through an MLP ending in 283 outputs.

    Args:
        airs_frames: length of each ``light_curve_airs`` vector (input size of
            the AIRS light-curve linear layer).
        fgs_frames: length of each ``light_curve_fgs`` vector (input size of
            the FGS light-curve linear layer).
    """

    def __init__(self, airs_frames, fgs_frames):
        super(Model, self).__init__()
        self.cnn_airs = CNN2D_AIRS(1, 16)
        self.cnn_fgs = CNN2D_FGS(1, 16)

        # LSTM input sizes are the flattened CNN outputs: 16 channels times the
        # spatial size after two 2x2 poolings (32x356 -> 8x89, 32x32 -> 8x8).
        self.lstm_airs = nn.LSTM(16 * 8 * 89, 128, batch_first=True)
        self.lstm_fgs = nn.LSTM(16 * 8 * 8, 128, batch_first=True)

        # Registered but not applied in forward(); kept so the parameter set
        # (and therefore checkpoint keys) stays unchanged.
        self.ln_airs = nn.LayerNorm(128)
        self.ln_fgs = nn.LayerNorm(128)

        self.fc_light_curve_airs = nn.Sequential(
            nn.Linear(airs_frames, 64),
            nn.ReLU(),
        )

        self.fc_light_curve_fgs = nn.Sequential(
            nn.Linear(fgs_frames, 64),
            nn.ReLU(),
        )

        self.fc_combined = nn.Sequential(
            nn.Linear(128 + 128 + 64 + 64, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 283)
        )

    def forward(self, airs_ch0, fgs1, light_curve_airs, light_curve_fgs):
        """Compute the 283-value prediction for a batch.

        Args:
            airs_ch0: tensor whose first two dimensions are (batch, frames) and
                whose remaining dimensions hold one 32x356 image per frame,
                e.g. (batch, frames, 1, 32, 356).
            fgs1: tensor whose first two dimensions are (batch, frames) and
                whose remaining dimensions hold one 32x32 image per frame.
                Its frame count may differ from that of ``airs_ch0``.
            light_curve_airs: (batch, airs_frames) tensor.
            light_curve_fgs: (batch, fgs_frames) tensor.

        Returns:
            Tensor of shape (batch, 283).
        """
        batch_size, airs_len = airs_ch0.shape[:2]
        # Read the FGS sequence length from the FGS tensor itself instead of
        # assuming it equals the AIRS one.
        fgs_len = fgs1.shape[1]

        # reshape() (unlike view()) also accepts non-contiguous inputs such as
        # transposed or sliced tensors; for contiguous inputs it is a plain view.
        airs_features = self.cnn_airs(airs_ch0.reshape(-1, 1, 32, 356))
        airs_features = airs_features.reshape(batch_size, airs_len, -1)
        _, (airs_hidden, _) = self.lstm_airs(airs_features)
        # Hidden state is (num_layers=1, batch, 128); drop the layer dimension.
        airs_hidden = airs_hidden.squeeze(0)

        fgs_features = self.cnn_fgs(fgs1.reshape(-1, 1, 32, 32))
        fgs_features = fgs_features.reshape(batch_size, fgs_len, -1)
        _, (fgs_hidden, _) = self.lstm_fgs(fgs_features)
        fgs_hidden = fgs_hidden.squeeze(0)

        light_curve_airs_features = self.fc_light_curve_airs(light_curve_airs)
        light_curve_fgs_features = self.fc_light_curve_fgs(light_curve_fgs)

        combined_features = torch.cat(
            (airs_hidden, fgs_hidden, light_curve_airs_features, light_curve_fgs_features),
            dim=1,
        )

        output = self.fc_combined(combined_features)
        return output

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both values size the input of the light-curve linear layers inside Model.
airs_frames = fgs_frames = 1125
# small batch size model
model = Model(airs_frames, fgs_frames).to(device)

In [None]:
# class CNN2D_AIRS(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(CNN2D_AIRS, self).__init__()
#         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
#         self.bn1 = nn.BatchNorm2d(out_channels)
#         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
#         self.bn2 = nn.BatchNorm2d(out_channels)
#         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

#     def forward(self, x):
# #         x = self.pool(F.relu(self.bn1(self.conv1(x))))
# #         x = self.pool(F.relu(self.bn2(self.conv2(x))))

#         x = self.pool(F.relu(self.conv1(x)))
#         x = self.pool(F.relu(self.conv2(x)))

#         return x

# class CNN2D_FGS(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(CNN2D_FGS, self).__init__()
#         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
#         self.bn1 = nn.BatchNorm2d(out_channels)
#         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
#         self.bn2 = nn.BatchNorm2d(out_channels)
#         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

#     def forward(self, x):
# #         x = self.pool(F.relu(self.bn1(self.conv1(x))))
# #         x = self.pool(F.relu(self.bn2(self.conv2(x))))

#         x = self.pool(F.relu(self.conv1(x)))
#         x = self.pool(F.relu(self.conv2(x)))
#         return x

# class Model(nn.Module):
#     def __init__(self, airs_frames, fgs_frames):
#         super(Model, self).__init__()
#         self.cnn_airs = CNN2D_AIRS(1, 16)
#         self.cnn_fgs = CNN2D_FGS(1, 16)
        
#         self.lstm_airs = nn.LSTM(16 * 8 * 89, 128, batch_first=True)
#         self.lstm_fgs = nn.LSTM(16 * 8 * 8, 128, batch_first=True)
        
#         self.bn_airs = nn.BatchNorm1d(128)
#         self.bn_fgs = nn.BatchNorm1d(128)
        
#         self.fc_light_curve_airs = nn.Sequential(
#             nn.Linear(airs_frames, 64),
#             nn.ReLU()
# #             nn.BatchNorm1d(64),

#         )
        
#         self.fc_light_curve_fgs = nn.Sequential(
#             nn.Linear(fgs_frames, 64),
#             nn.ReLU()
# #             nn.BatchNorm1d(64),

#         )
        
#         self.fc_combined = nn.Sequential(            
#             nn.Linear(128 + 128 + 64 + 64, 256),
#             nn.ReLU(),
#             nn.BatchNorm1d(256),
# #             nn.Dropout(0.15),
#             nn.Linear(256, 256),
#             nn.ReLU(),
#             nn.BatchNorm1d(256),
# #             nn.Dropout(0.15),
#             nn.Linear(256, 283)
#         )

#     def forward(self, airs_ch0, fgs1, light_curve_airs, light_curve_fgs):
#         batch_size, frames, _, _, _ = airs_ch0.shape
        
#         airs_features = self.cnn_airs(airs_ch0.view(-1, 1, 32, 356))
#         airs_features = airs_features.view(batch_size, frames, -1)
#         _, (airs_hidden, _) = self.lstm_airs(airs_features)
# #         airs_hidden = self.bn_airs(airs_hidden.squeeze(0))
#         airs_hidden = airs_hidden.squeeze(0)
        
        
#         fgs_features = self.cnn_fgs(fgs1.view(-1, 1, 32, 32))
#         fgs_features = fgs_features.view(batch_size, frames, -1)
#         _, (fgs_hidden, _) = self.lstm_fgs(fgs_features)
# #         fgs_hidden = self.bn_fgs(fgs_hidden.squeeze(0))
#         fgs_hidden = fgs_hidden.squeeze(0)
        
#         light_curve_airs_features = self.fc_light_curve_airs(light_curve_airs)
#         light_curve_fgs_features = self.fc_light_curve_fgs(light_curve_fgs)
        
#         combined_features = torch.cat((airs_hidden, fgs_hidden, light_curve_airs_features, light_curve_fgs_features), dim=1)
        
#         output = self.fc_combined(combined_features)
#         return output

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# airs_frames = 1125
# fgs_frames = 1125
# model = Model(airs_frames, fgs_frames).to(device) # for larger batch size 

In [None]:
# Total number of scalar values over all parameters registered on `model`.
c = sum(param.numel() for param in model.parameters())
print(c)

In [None]:
def planetnumber(filename):
    """Return the integer written before the first underscore of ``filename``.

    When the name contains no underscore the whole string is parsed.
    Raises ``ValueError`` if that leading part is not an integer literal.
    """
    prefix, _, _ = filename.partition('_')
    return int(prefix)

class ARIEL(Dataset):
    """Dataset yielding one planet's AIRS/FGS frame stacks, light curves and label row.

    AIRS files are spread over four directories; file names are ordered by the
    integer returned by ``planetnumber`` and the ``[start:end]`` slice of that
    ordering defines the samples of this dataset.

    Args:
        airs_dir1, airs_dir2, airs_dir3, airs_dir4: directories holding the
            AIRS ``.npy`` files.
        fgs_dir: directory holding the ``<planet>_fgs.npy`` files.
        start, end: slice bounds applied to the planet-ordered file list.
        labels_path: CSV with a ``planet_id`` column followed by the label
            columns. Defaults to the path that was previously hard-coded.
    """

    def __init__(self, airs_dir1, airs_dir2, airs_dir3, airs_dir4, fgs_dir , start , end,
                 labels_path="/kaggle/input/ariel-data-challenge-2024/train_labels.csv"):
        self.airs_dir1 = airs_dir1
        self.airs_dir2 = airs_dir2
        self.airs_dir3 = airs_dir3
        self.airs_dir4 = airs_dir4

        # List every AIRS directory exactly once and remember, for each file
        # name, the first directory that contains it. __getitem__ then needs a
        # dict lookup instead of up to four os.listdir() calls per sample.
        self.airs_full = []
        self._airs_dir_of = {}
        for directory in (airs_dir1, airs_dir2, airs_dir3, airs_dir4):
            names = os.listdir(directory)
            self.airs_full += names
            for name in names:
                self._airs_dir_of.setdefault(name, directory)

        self.fgs_dir = fgs_dir

        self.airs_list = sorted(self.airs_full, key=planetnumber)[start:end]
        # Not used for loading (the FGS path is derived from the planet number);
        # kept as an attribute for backward compatibility.
        self.fgs_list = sorted(os.listdir(self.fgs_dir), key=planetnumber)[start:end]

        self.labels = pd.read_csv(labels_path)

    @staticmethod
    def _min_max_scale(series):
        """Scale ``series`` to [0, 1]; a constant series maps to all zeros instead of NaN."""
        lowest = np.min(series)
        span = np.max(series) - lowest
        if span == 0:
            return np.zeros_like(series, dtype=float)
        return (series - lowest) / span

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with
            ``airs_frames``: float tensor (1125, 1, 32, 356),
            ``fgs_frames``: float tensor (1125, 1, 32, 32),
            ``airs_1d`` / ``fgs_1d``: float tensors (1125,), the per-frame pixel
            sums min-max scaled to [0, 1],
            ``label``: float tensor with the planet's label row (every column
            after the first).

        Raises:
            IndexError: if the labels table has no row for the planet.
        """
        planet = self.airs_list[index]
        airs_file = os.path.join(self._airs_dir_of[planet], planet)

        planet_num = planetnumber(planet)
        fgs_file = f"{self.fgs_dir}/{planet_num}_fgs.npy"

        airs_arr_frames = np.load(airs_file).reshape(1125, 32, 356)
        fgs_arr_frames = np.load(fgs_file).reshape(1125, 32, 32)

        # One value per frame: the sum over all pixels, then scaled to [0, 1].
        airs_1d = self._min_max_scale(np.sum(airs_arr_frames, axis=(1, 2)))
        fgs_1d = self._min_max_scale(np.sum(fgs_arr_frames, axis=(1, 2)))

        airs_arr_frames = torch.from_numpy(airs_arr_frames).float().unsqueeze(1)  # Add channel dimension
        fgs_arr_frames = torch.from_numpy(fgs_arr_frames).float().unsqueeze(1)  # Add channel dimension

        airs_1d = torch.from_numpy(airs_1d).float()
        fgs_1d = torch.from_numpy(fgs_1d).float()

        matching_rows = self.labels[self.labels["planet_id"] == planet_num]
        if matching_rows.empty:
            raise IndexError(f"no label row for planet_id {planet_num}")
        filtered_data = matching_rows.iloc[0, 1:].values
        output = torch.tensor(filtered_data).float()

        return {
            'airs_frames': airs_arr_frames,
            'fgs_frames': fgs_arr_frames,
            'airs_1d': airs_1d,
            'fgs_1d': fgs_1d,
            'label': output
        }

    def __len__(self):
        """Number of planets in the selected slice."""
        return len(self.airs_list)

In [None]:
# Device selection: CUDA when available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Sizes of the two light-curve inputs expected by Model.
airs_frames = 1125
fgs_frames = 1125

# Build the network on the device and wrap it in DataParallel.
model = nn.DataParallel(Model(airs_frames, fgs_frames).to(device)).to(device)

# Loss, optimiser and plateau-based learning-rate schedule.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.50,
    patience=3,
    min_lr=1e-8,
)

# Set `weights` to a checkpoint path to resume; None starts from scratch.
weights = None
# weights="/kaggle/input/arieldata/epoch450.pth"

start_epoch = 0
if weights:
    # Restore model, optimiser and scheduler state and continue after the saved epoch.
    checkpoint = torch.load(weights, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
    print(f"Resuming learning rate: {optimizer.param_groups[0]['lr']}")

# Input locations: four AIRS shards and one FGS directory.
part1 = "/kaggle/input/arieldata/airs-p1"
part2 = "/kaggle/input/arieldata/airs-p2"
part3 = "/kaggle/input/arieldata/airs-p3"
part4 = "/kaggle/input/arieldata/airs-p4"
part5 = "/kaggle/input/arieldata/fgs-p"

train_batchsize = 5
val_batchsize = 2

# Planets [0, 635) of the ordered file list train the model; [635, 665) validate it.
train_data = ARIEL(part1, part2, part3, part4, part5, start=0, end=635)
val_data = ARIEL(part1, part2, part3, part4, part5, start=635, end=665)

train_dataloader = DataLoader(
    train_data, batch_size=train_batchsize, shuffle=True, num_workers=8
)
val_dataloader = DataLoader(
    val_data, batch_size=val_batchsize, shuffle=False, num_workers=8
)

print(f"Training batches: {len(train_dataloader)}, Validation batches: {len(val_dataloader)}")


In [None]:
import glob

epochs = 141                   # epochs to run in this session
total = start_epoch + epochs   # exclusive end of the epoch counter (start_epoch > 0 when resuming)
print("training started")
best_val_loss = float('inf')
patience = 15                  # epochs without a better validation loss before stopping
no_improve = 0                 # consecutive epochs without a better validation loss
count = 0                      # number of times the manual LR override below has fired

def cleanup_old_checkpoints(directory, keep=None):
    """Remove the .pth files in ``directory``, best effort.

    Args:
        directory: folder to scan for ``*.pth`` files.
        keep: optional path of a checkpoint that must survive the cleanup.
            With the default ``None`` every ``.pth`` file is removed.

    Removal errors are reported and skipped so one undeletable file does not
    abort training.
    """
    keep_path = os.path.abspath(keep) if keep is not None else None
    for f in glob.glob(os.path.join(directory, "*.pth")):
        if keep_path is not None and os.path.abspath(f) == keep_path:
            continue
        try:
            os.remove(f)
            print(f"Removed old checkpoint: {f}")
        except OSError as e:
            print(f"Error removing {f}: {e}")

for epoch in range(start_epoch, total):
    model.train()
    train_loss = 0
    val_loss = 0

    # ---- training pass ----
    # Batch tensors get their own names so the integer settings `airs_frames`
    # and `fgs_frames` used to construct the model are not overwritten.
    for batch in train_dataloader:
        optimizer.zero_grad()
        airs_batch = batch['airs_frames'].to(device)
        fgs_batch = batch['fgs_frames'].to(device)
        airs_1d = batch['airs_1d'].to(device)
        fgs_1d = batch['fgs_1d'].to(device)
        label = batch['label'].to(device)
        out = model(airs_batch, fgs_batch, airs_1d, fgs_1d)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Mean of the per-batch losses.
    train_loss /= len(train_dataloader)

    # Every fifth epoch, show the first three targets and predictions of the
    # last training batch as a quick sanity check.
    if epoch%5==0 and epoch>0:
        print(f" label {(label[0][:3].cpu().detach().numpy())} , output {(out[0][:3].cpu().detach().numpy())}")

    # One-off manual override: after 7 epochs without improvement, drop the
    # learning rate straight to 5e-7 if it is still above 1e-6.
    if no_improve==7 and count==0 and optimizer.param_groups[0]['lr'] > 1e-6:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 5e-7
            print(f"Learning rate manually set to 5e-7 at epoch {epoch}")
        count+=1

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            airs_batch = batch['airs_frames'].to(device)
            fgs_batch = batch['fgs_frames'].to(device)
            airs_1d = batch['airs_1d'].to(device)
            fgs_1d = batch['fgs_1d'].to(device)
            label = batch['label'].to(device)
            out = model(airs_batch, fgs_batch, airs_1d, fgs_1d)
            loss = criterion(out, label)
            val_loss += loss.item()

    val_loss /= len(val_dataloader)

    # Let the plateau scheduler react to the validation loss and report any change.
    prev = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    nex = optimizer.param_groups[0]['lr']

    if prev!=nex:
        print("LR decreased to ", nex)

    print(f"Epoch {epoch+1}/{total}, Train loss: {train_loss}, Val loss: {val_loss}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve = 0

        model_filename = f"epoch{epoch}.pth"
        model_path = os.path.join("/kaggle/working", model_filename)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': train_loss,
        }
        # Write the new checkpoint first and only then prune the older ones, so
        # a failed save never leaves the run without any checkpoint on disk.
        torch.save(checkpoint, model_path)
        cleanup_old_checkpoints("/kaggle/working", keep=model_path)
        print(f"Model saved at epoch {epoch}, old files clean, {len(os.listdir('/kaggle/working'))}")

    else:
        no_improve += 1
        if no_improve == patience:
            print("Early stopping triggered at epoch", epoch)
            break

In [None]:
# Snapshot of the state at the end of the run (separate from the
# best-validation checkpoint written inside the training loop).
model_filename = "lastrun.pth"
model_path = os.path.join("/kaggle/working", model_filename)
checkpoint = dict(
    epoch=epoch,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    scheduler_state_dict=scheduler.state_dict(),
    loss=train_loss,
)
torch.save(checkpoint, model_path)
print(f"Model saved at epoch {epoch}")


In [None]:
# model.eval()  
# model_filename = f"epoch{epoch}-loss{epoch_loss:.8f}.pth"
# model_path = os.path.join("/kaggle/working", model_filename)

# checkpoint = {
#     'epoch': epoch,
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'scheduler_state_dict': scheduler.state_dict(),
#     'loss': epoch_loss,
# }
# torch.save(checkpoint, model_path)        
# print(f"Model saved at epoch {epoch}")
# print()

In [None]:
# Report the learning rate currently set on the optimizer's first parameter group.
print(f"Learning rate final to {optimizer.param_groups[0]['lr']} ")
