In [1]:
from models.ResNet import get_quantized_resnet18_model
from models.DenseNet import get_densenet121_model
from models.EfficientNet import get_efficientnet_b2_model
from train import train_model
from data_loader import train_loader, test_loader
import torch
from utils.class_weights import get_class_weights

# Pipeline switch: True -> quantized ResNet-18 on CPU, False -> float EfficientNet-B2.
USE_QUANTIZED = False
NUM_CLASSES = 2  # binary classification head, shared by model and class weights

# Quantized models run on CPU; the float model prefers CUDA when available.
if USE_QUANTIZED:
    device = torch.device('cpu')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if USE_QUANTIZED:
    # The quantized factory returns a (model, transform) pair. Build it exactly
    # once (the model used to be constructed twice, the first result discarded).
    model, final_transform = get_quantized_resnet18_model(num_classes=NUM_CLASSES)
    # Install the model-specific final transform on both datasets so train and
    # test inputs are preprocessed the way the quantized model expects.
    train_loader.dataset.final_transform = final_transform
    test_loader.dataset.final_transform = final_transform
    # eval() only on the quantized path (inference mode). The float model is
    # trained in the next cell, so switching it to eval mode here is
    # counterproductive (affects BatchNorm/Dropout behaviour).
    model.eval()
else:
    # Alternative backbone: get_densenet121_model(num_classes=NUM_CLASSES).
    model = get_efficientnet_b2_model(num_classes=NUM_CLASSES)

model.to(device)

# Per-class loss weights derived from the training dataset, placed on the
# same device as the model (get_class_weights expects a device string).
class_weights = get_class_weights(train_loader.dataset, num_classes=NUM_CLASSES, device=str(device))

Training set size: 960
Test set size: 240
Downloading: "https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b2_rwightman-c35c1473.pth


100%|██████████| 35.2M/35.2M [00:00<00:00, 111MB/s] 


In [2]:
# Fine-tune the selected model with class-weighted loss. model_id presumably
# names the best-model checkpoint train_model saves when val loss improves
# (see its "saving best model" log) — verify in train.py.
trained_model = train_model(
    model,
    train_loader,
    test_loader,
    num_epochs=20,
    learning_rate=2e-4,
    device=str(device),
    class_weights=class_weights,
    model_id='efficient_net_b2',
)

Epoch 1/20, Loss: 0.3697, Val Loss: 0.6210, Val Accuracy: 0.7583, Val AUC: 0.8263
✅ Validation loss improved, saving best model.
Epoch 2/20, Loss: 0.2290, Val Loss: 0.3625, Val Accuracy: 0.8000, Val AUC: 0.9137
✅ Validation loss improved, saving best model.
Epoch 3/20, Loss: 0.1219, Val Loss: 0.5101, Val Accuracy: 0.8417, Val AUC: 0.9316
⚠️ No improvement in val loss for 1 epoch(s).
Epoch 4/20, Loss: 0.0788, Val Loss: 0.5794, Val Accuracy: 0.8417, Val AUC: 0.9343
⚠️ No improvement in val loss for 2 epoch(s).
Epoch 5/20, Loss: 0.1223, Val Loss: 0.5536, Val Accuracy: 0.8458, Val AUC: 0.9163
⚠️ No improvement in val loss for 3 epoch(s).
Epoch 6/20, Loss: 0.0830, Val Loss: 0.2827, Val Accuracy: 0.9042, Val AUC: 0.9634
✅ Validation loss improved, saving best model.
Epoch 7/20, Loss: 0.0764, Val Loss: 0.2604, Val Accuracy: 0.9083, Val AUC: 0.9656
✅ Validation loss improved, saving best model.
Epoch 8/20, Loss: 0.0690, Val Loss: 0.3874, Val Accuracy: 0.8958, Val AUC: 0.9594
⚠️ No improvement 

In [3]:
import torch

In [4]:
from pathlib import Path

# Persist only the learned weights (state_dict); reload later with
# model.load_state_dict(torch.load(save_path)).
# NOTE(review): train_model already logs "saving best model" during training;
# confirm whether trained_model holds the best-val weights or the final epoch.
save_path = Path("saved_models") / "efficient_net_b2.pth"
# torch.save does not create missing parent directories, so ensure it exists
# to make this cell work on a fresh checkout.
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(trained_model.state_dict(), save_path)
print("Model saved successfully!")

Model saved successfully!
