# Imports

In [None]:
from efficientnet_pytorch import EfficientNet
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim

# Dataset

In [None]:


class SpectrogramDataset(Dataset):
    """Map-style dataset that loads one ``.npy`` spectrogram per index.

    Args:
        file_paths: Sequence of paths to ``.npy`` files; item ``i`` is loaded
            from ``file_paths[i]``.
        transform: Optional callable applied to each loaded array.
        labels: Optional sequence of targets aligned index-for-index with
            ``file_paths``. When given, ``__getitem__`` returns
            ``(spectrogram, label)`` pairs, which is what a supervised training
            loop unpacks; when omitted, only the spectrogram is returned.

    Raises:
        ValueError: if ``labels`` is given and its length differs from
            ``file_paths``.
    """

    def __init__(self, file_paths, transform=None, labels=None):
        # Catch misaligned labels up front rather than with an IndexError
        # (or silently wrong targets) in the middle of an epoch.
        if labels is not None and len(labels) != len(file_paths):
            raise ValueError(
                f"labels has {len(labels)} entries but file_paths has "
                f"{len(file_paths)}"
            )
        self.file_paths = file_paths
        self.transform = transform
        self.labels = labels

    def __len__(self):
        """Return the number of spectrogram files."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load item ``idx``; return the spectrogram, or ``(spectrogram, label)``."""
        spectrogram = np.load(self.file_paths[idx])
        if self.transform is not None:
            spectrogram = self.transform(spectrogram)
        if self.labels is None:
            return spectrogram
        return spectrogram, self.labels[idx]


# Collect every .npy file in ./spectrograms. os.listdir returns entries in
# arbitrary order, so sort them: this makes the file order reproducible and
# lets it be lined up with per-file labels.
file_paths = sorted(
    os.path.join("spectrograms", f)
    for f in os.listdir("spectrograms")
    if f.endswith(".npy")
)
# NOTE(review): neither a transform nor labels are supplied here, yet the
# training loop expects preprocessed (inputs, labels) batches.
dataset = SpectrogramDataset(file_paths)
# NOTE(review): num_workers > 0 with a Dataset class defined inside a notebook
# can fail on platforms that start workers with "spawn" (Windows/macOS);
# drop to num_workers=0 there if workers crash.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)


# EfficientNet-B3 model

In [None]:
# The number of distinct values in column 'B' of the training split sets the
# size of the classifier head.
INPUT_FILE = '../data/cleaned/70_15_15_cleaned_train.parquet'
df = pd.read_parquet(INPUT_FILE)
num_classes = class_count = len(df['B'].unique())

# Load pretrained EfficientNet-B3 and replace its final fully connected layer
# with a freshly initialised one that has `num_classes` outputs.
model = EfficientNet.from_pretrained('efficientnet-b3')
in_features = model._fc.in_features
model._fc = torch.nn.Linear(in_features=in_features, out_features=num_classes)


## Preprocess spectrogram data

In [None]:


def _spectrogram_as_float32(spectrogram):
    """Return *spectrogram* as a float32 array.

    np.load returns whatever dtype was saved, and ToPILImage cannot convert
    float64 arrays (it supports float32 via PIL's 'F' mode).
    """
    return np.asarray(spectrogram, dtype=np.float32)


def _repeat_single_channel(tensor):
    """Tile a (1, H, W) tensor to (3, H, W); pass other channel counts through.

    A 2-D spectrogram becomes a single-channel image, but the Normalize step
    below carries three per-channel statistics and the pretrained
    EfficientNet stem takes three input channels.
    """
    if tensor.shape[0] == 1:
        return tensor.repeat(3, 1, 1)
    return tensor


# Preprocessing pipeline for raw .npy spectrograms.
# NOTE(review): ToTensor only rescales uint8 images to [0, 1]; float ('F'
# mode) spectrograms keep their raw value range, so the ImageNet statistics
# below may not centre them well — confirm the saved value range.
transform = transforms.Compose([
    transforms.Lambda(_spectrogram_as_float32),
    transforms.ToPILImage(),
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Lambda(_repeat_single_channel),
    transforms.Normalize(  # Normalize with ImageNet stats
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


## Training

In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Multiply the learning rate by 0.1 every 7 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# The dataset cell constructs SpectrogramDataset without the preprocessing
# pipeline defined in the "Preprocess Spectrogram data" cell; attach it here.
# The DataLoader reads samples through `dataset` when each epoch's iterator is
# created, so setting the attribute before the loop is enough.
dataset.transform = transform

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        # CrossEntropyLoss needs targets: fail with an explicit message rather
        # than an opaque unpacking error if the dataset yields bare inputs.
        if not isinstance(batch, (list, tuple)) or len(batch) != 2:
            raise TypeError(
                "dataloader must yield (inputs, labels) batches; the dataset "
                "has to return (spectrogram, label) pairs"
            )
        inputs, labels = batch
        # Match the model's float32 parameters even if the arrays were saved
        # with another floating dtype.
        inputs = inputs.to(device, dtype=torch.float32)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    scheduler.step()
    # Average over batches actually seen; avoids ZeroDivisionError when the
    # dataloader is empty (e.g. no .npy files were found).
    avg_loss = running_loss / num_batches if num_batches else float('nan')
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
