In [1]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.preprocessing import label_binarize
from tqdm import tqdm
from sklearn.model_selection import train_test_split


In [2]:

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BASE_DIR = r"C:\Project\Enhanced_Dataset\dataset"
IMAGE_SIZE = (150, 150)
BATCH_SIZE = 64
EPOCHS = 5
CLASSES = ['no', 'sphere', 'vort']
NUM_CLASSES = len(CLASSES)
LEARNING_RATE = 5e-5  # Adjusted for stability

In [3]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(150, scale=(0.7, 1.0)),  # 🔥 More aggressive cropping
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),  # 🔥 Increase rotation range
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),  # 🔥 Add hue jitter
    transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),  # 🔥 Introduce translations
    transforms.GaussianBlur(3),  # 🔥 Simulate telescope noise
    transforms.Normalize((0.5,), (0.5,))
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [4]:
# ✅ Dataset Class (Handles Corrupted Files)
class DarkMatterDataset(Dataset):
    def __init__(self, base_dir, mode, transform):
        self.img_paths = []
        self.labels = []
        self.transform = transform

        for label, cls in enumerate(CLASSES):
            cls_dir = os.path.join(base_dir, mode, cls)
            for img_file in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_file)
                self.img_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]

        try:
            img = np.load(img_path)
        except Exception as e:
            print(f"🚨 ERROR: Skipping corrupted file {img_path} | {e}")
            return self.__getitem__((idx + 1) % len(self.img_paths))

        img = cv2.resize(img, IMAGE_SIZE, interpolation=cv2.INTER_AREA)
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        img = np.stack([img] * 3, axis=0).astype(np.float32)

        img = self.transform(torch.tensor(img))  # ✅ Transform applied here
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return img, label

In [5]:
train_dataset = DarkMatterDataset(BASE_DIR, 'train', train_transform)
val_dataset = DarkMatterDataset(BASE_DIR, 'val', val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [6]:
class DarkMatterClassifier(nn.Module):
    def __init__(self, num_classes):
        super(DarkMatterClassifier, self).__init__()
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        for param in self.resnet.parameters():
            param.requires_grad = False

        for param in self.resnet.layer3.parameters():
            param.requires_grad = True
        for param in self.resnet.layer4.parameters():
            param.requires_grad = True  

        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(in_features, 512),  # 🔥 Reduce FC layer size
            nn.ReLU(),
            nn.Dropout(0.6),  # 🔥 Increase dropout
            nn.Linear(512, num_classes),  # 🔥 Reduce FC layer size
    )
    def forward(self, x):
        return self.resnet(x)

In [7]:
# ✅ Instantiate Model
model = DarkMatterClassifier(NUM_CLASSES).to(DEVICE)

In [8]:

# ✅ Loss, Optimizer & Scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=1, verbose=True)




In [9]:
# ✅ Early Stopping Class
# class EarlyStopping:
#     def __init__(self, patience=3, mode="max", min_delta=0.001):
#         self.patience = patience
#         self.mode = mode
#         self.min_delta = min_delta
#         self.best_score = None
#         self.counter = 0
#         self.early_stop = False

#     def __call__(self, score):
#         if self.best_score is None:
#             self.best_score = score
#         elif (self.mode == "max" and score < self.best_score + self.min_delta) or \
#              (self.mode == "min" and score > self.best_score - self.min_delta):
#             self.counter += 1
#             if self.counter >= self.patience:
#                 self.early_stop = True
#         else:
#             self.best_score = score
#             self.counter = 0  

In [10]:
def train_model(model, train_loader, val_loader, epochs):
    best_auc = 0.0
    patience = 3  # Stop if no improvement for 3 epochs
    counter = 0
    model = model.to(DEVICE)

    for epoch in range(epochs):
        model.train()
        all_labels, all_probs, all_preds = [], [], []

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            probs = torch.softmax(outputs, dim=1).cpu().detach().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().detach().numpy()

            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs)
            all_preds.extend(preds)

        if len(all_labels) == 0:
            print("🚨 ERROR: No labels collected! Check dataset loading.")
            return  

        all_labels = label_binarize(all_labels, classes=[0, 1, 2])

        train_acc = accuracy_score(np.argmax(all_labels, axis=1), all_preds)
        train_auc = roc_auc_score(all_labels, np.array(all_probs), multi_class='ovr')

        # ✅ Validation Phase
        model.eval()
        val_labels, val_probs = [], []

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)

                val_probs.extend(torch.softmax(outputs, dim=1).cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        if len(val_labels) == 0:
            print("🚨 ERROR: No validation labels! Check validation dataset.")
            return  

        val_labels = label_binarize(val_labels, classes=[0, 1, 2])
        val_auc = roc_auc_score(val_labels, np.array(val_probs), multi_class='ovr')

        print(f"Epoch [{epoch+1}/{epochs}] | Train AUC: {train_auc:.4f} | Val AUC: {val_auc:.4f}")

        # ✅ Early Stopping Logic
        if val_auc > best_auc:
            best_auc = val_auc
            counter = 0  
            torch.save(model.state_dict(), "best_model.pth")
            print(f"✅ Best model saved with AUC: {best_auc:.4f}")
        else:
            counter += 1
            print(f"⚠️ No improvement for {counter} epochs.")
            if counter >= patience:
                print("🚀 Early stopping activated!")
                break

        scheduler.step(val_auc)  # ✅ Fix: Pass validation AUC as metric


In [None]:
# # ✅ Train the model
train_model(model, train_loader, val_loader, EPOCHS)

Epoch 1/5:   0%|          | 0/2813 [00:00<?, ?it/s]

In [None]:
train_acc = accuracy_score(all_labels.argmax(axis=1), all_preds)
val_acc = accuracy_score(val_labels.argmax(axis=1), np.argmax(val_probs, axis=1))

print(f"Epoch [{epoch+1}/{epochs}] | Train AUC: {train_auc:.4f} | Val AUC: {val_auc:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")


NameError: name 'all_labels' is not defined

In [None]:
# import os
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import cv2
# from PIL import Image
# from torchvision import models, transforms
# from torch.utils.data import DataLoader, Dataset
# from sklearn.metrics import accuracy_score, roc_auc_score
# from tqdm import tqdm

# # Configuration
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BASE_DIR = r"C:\Project\Enhanced_Dataset\dataset"
# IMAGE_SIZE = (150, 150)
# BATCH_SIZE = 128
# EPOCHS = 30
# CLASSES = ['no', 'sphere', 'vort']
# NUM_CLASSES = len(CLASSES)

# # ImageNet statistics
# RESNET_MEAN = [0.485, 0.456, 0.406]
# RESNET_STD = [0.229, 0.224, 0.225]

# ### 🖼️ Enhanced Data Loading with Physics-based Augmentation
# class DarkMatterDataset(Dataset):
#     def __init__(self, base_dir, mode, transform):
#         self.img_paths = []
#         self.labels = []
#         self.transform = transform

#         # Class balancing through oversampling
#         class_counts = {cls: len(os.listdir(os.path.join(base_dir, mode, cls))) 
#                         for cls in CLASSES}
#         max_count = max(class_counts.values())

#         for label, cls in enumerate(CLASSES):
#             cls_dir = os.path.join(base_dir, mode, cls)
#             files = os.listdir(cls_dir)

#             # Oversample minority classes
#             if len(files) < max_count:
#                 files = np.random.choice(files, size=max_count, replace=True)

#             self.img_paths.extend([os.path.join(cls_dir, f) for f in files])
#             self.labels.extend([label] * len(files))

#     def __len__(self):
#         return len(self.img_paths)

#     def __getitem__(self, idx):
#         img_path = self.img_paths[idx]
#         img = np.load(img_path).astype(np.float32)

#         # Normalize the image
#         img = np.clip(img, 0, 1)

#         # Convert to PIL image
#         img = Image.fromarray((img * 255).astype(np.uint8)).convert("RGB")

#         # Apply transform
#         if self.transform:
#             img = self.transform(img)

#         label = torch.tensor(self.labels[idx], dtype=torch.long)
#         return img, label

# ### 🔄 Paper-specified Transformations
# train_transform = transforms.Compose([
#     transforms.Resize(IMAGE_SIZE),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.RandomRotation(90),
#     transforms.ColorJitter(brightness=0.2, contrast=0.2),
#     transforms.GaussianBlur(3, sigma=(0.1, 0.5)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=RESNET_MEAN, std=RESNET_STD)
# ])

# val_transform = transforms.Compose([
#     transforms.Resize(IMAGE_SIZE),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=RESNET_MEAN, std=RESNET_STD)
# ])

# ### 🧠 Enhanced ResNet Model with Channel Adaptation
# class DarkMatterClassifier(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()

#         # Channel adapter (learns grayscale→RGB mapping)
#         self.channel_adapter = nn.Sequential(
#             nn.Conv2d(1, 3, kernel_size=7, padding=3),
#             nn.ReLU(),
#             nn.BatchNorm2d(3)
#         )

#         # Pretrained ResNet50
#         self.resnet = models.resnet50(pretrained=True)

#         # Freezing strategy (unfreeze last 2 stages)
#         for name, param in self.resnet.named_parameters():
#             if 'layer3' not in name and 'layer4' not in name and 'fc' not in name:
#                 param.requires_grad = False

#         # Modified classifier head
#         in_features = self.resnet.fc.in_features
#         self.resnet.fc = nn.Sequential(
#             nn.Linear(in_features, 2048),
#             nn.GELU(),
#             nn.Dropout(0.3),
#             nn.Linear(2048, num_classes)
#         )

#     def forward(self, x):
#         x = self.channel_adapter(x)  # Learned channel conversion
#         return self.resnet(x)

# ### 🚀 Training Infrastructure
# def train_model():
#     model = DarkMatterClassifier(NUM_CLASSES).to(DEVICE)
#     optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, mode='max', factor=0.2, patience=3, verbose=True
#     )
#     criterion = nn.CrossEntropyLoss()

#     # Data loaders
#     train_set = DarkMatterDataset(BASE_DIR, 'train', train_transform)
#     val_set = DarkMatterDataset(BASE_DIR, 'val', val_transform)

#     train_loader = DataLoader(
#         train_set, batch_size=BATCH_SIZE, shuffle=True,
#         num_workers=4, pin_memory=torch.cuda.is_available()
#     )
#     val_loader = DataLoader(
#         val_set, batch_size=BATCH_SIZE, num_workers=4,
#         pin_memory=torch.cuda.is_available()
#     )

#     best_auc = 0
#     for epoch in range(EPOCHS):
#         model.train()
#         train_preds, train_labels = [], []

#         for images, labels in tqdm(train_loader, desc=f'Train Epoch {epoch+1}'):
#             images, labels = images.to(DEVICE), labels.to(DEVICE)

#             optimizer.zero_grad()
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#             probs = torch.softmax(outputs, dim=1)
#             train_preds.append(probs.detach().cpu())
#             train_labels.append(labels.cpu())
            
#             torch.cuda.empty_cache()

#         # Validation phase
#         model.eval()
#         val_preds, val_labels = [], []

#         with torch.no_grad():
#             for images, labels in tqdm(val_loader, desc=f'Val Epoch {epoch+1}'):
#                 images = images.to(DEVICE)
#                 outputs = model(images)

#                 val_preds.append(torch.softmax(outputs, dim=1).cpu())
#                 val_labels.append(labels.cpu())

#         # Save best model
#         val_auc = roc_auc_score(val_labels, torch.cat(val_preds).numpy(), multi_class='ovo')
#         if val_auc > best_auc:
#             best_auc = val_auc
#             torch.save(model.state_dict(), 'best_model.pth')
#             print(f"🏆 New best model saved with AUC: {best_auc:.4f}")

# if __name__ == "__main__":
#     train_model()


In [None]:
# import os
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torchvision import models, transforms
# from torch.utils.data import DataLoader, Dataset
# from sklearn.metrics import accuracy_score, roc_auc_score
# from tqdm import tqdm

# # Configuration
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BASE_DIR = r"C:\Project\Enhanced_Dataset\dataset"
# IMAGE_SIZE = (150, 150)
# BATCH_SIZE = 64  # Optimized batch size
# EPOCHS = 20
# CLASSES = ['no', 'sphere', 'vort']
# NUM_CLASSES = len(CLASSES)

# ### 📊 Data Loader with Efficient Preprocessing
# class DarkMatterDataset(Dataset):
#     def __init__(self, base_dir, mode, transform):
#         self.img_paths = []
#         self.labels = []
#         self.transform = transform

#         for label, cls in enumerate(CLASSES):
#             cls_dir = os.path.join(base_dir, mode, cls)

#             for img_file in os.listdir(cls_dir):
#                 img_path = os.path.join(cls_dir, img_file)
#                 self.img_paths.append(img_path)
#                 self.labels.append(label)

#     def __len__(self):
#         return len(self.img_paths)

#     def __getitem__(self, idx):
#         img_path = self.img_paths[idx]

#         # Lazily load the image
#         img = np.load(img_path)

#         # Handle incorrect dimensions
#         if img.shape != IMAGE_SIZE:
#             img = np.resize(img, IMAGE_SIZE)

#         # Expand dimensions and convert to 3 channels
#         img = np.expand_dims(img, axis=0)
#         img = np.repeat(img, 3, axis=0)

#         # Convert to float32
#         img = img.astype(np.float32) / 255.0

#         # Convert to PyTorch tensor
#         img = torch.tensor(img, dtype=torch.float32)
#         img = self.transform(img)

#         label = torch.tensor(self.labels[idx], dtype=torch.long)
#         return img, label


# # Image Transformations with Augmentation
# train_transform = transforms.Compose([
#     transforms.Resize(IMAGE_SIZE),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(15),
#     transforms.ColorJitter(brightness=0.2, contrast=0.2),
#     transforms.Normalize((0.5,), (0.5,))
# ])

# val_transform = transforms.Compose([
#     transforms.Resize(IMAGE_SIZE),
#     transforms.Normalize((0.5,), (0.5,))
# ])

# # Dataloaders
# train_dataset = DarkMatterDataset(BASE_DIR, 'train', train_transform)
# val_dataset = DarkMatterDataset(BASE_DIR, 'val', val_transform)

# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


# ### 🔥 **ResNet50 Model with Regularization and Dropout**
# class DarkMatterClassifier(nn.Module):
#     def __init__(self, num_classes):
#         super(DarkMatterClassifier, self).__init__()
        
#         # Load pre-trained ResNet50
#         self.resnet = models.resnet50(pretrained=True)
        
#         # Freeze initial layers for transfer learning
#         for param in self.resnet.parameters():
#             param.requires_grad = False
        
#         # Modify final layers
#         in_features = self.resnet.fc.in_features
#         self.resnet.fc = nn.Sequential(
#             nn.Linear(in_features, 1024),
#             nn.ReLU(),
#             nn.Dropout(0.4),                      # Increased dropout for regularization
#             nn.Linear(1024, num_classes)
#         )

#     def forward(self, x):
#         return self.resnet(x)

# # Instantiate the model
# model = DarkMatterClassifier(NUM_CLASSES).to(DEVICE)


In [None]:
# # Loss and optimizer
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

# # Learning rate scheduler
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer, mode='min', factor=0.2, patience=3, verbose=True
# )


In [None]:
# # Force use of NVIDIA GPU
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# # Send your model and tensors to the GPU
# model = model.to(device)

In [None]:
# from sklearn.metrics import accuracy_score, roc_auc_score

# def train_model(model, train_loader, val_loader, epochs):
#     best_auc = 0.0
#     model = model.to(device)
    
#     for epoch in range(epochs):
#         model.train()
        
#         all_labels = []
#         all_probs = []    # 💡 Store probabilities instead of class indices
#         all_preds = []
        
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
            
#             optimizer.zero_grad()
#             outputs = model(images)
            
#             # Apply softmax to get probabilities
#             probs = torch.softmax(outputs, dim=1).detach().cpu().numpy()
#             preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
            
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
            
#             all_labels.extend(labels.cpu().numpy())
#             all_probs.extend(probs)          # ✅ Append probabilities
#             all_preds.extend(preds)
        
#         # Calculate metrics
#         train_acc = accuracy_score(all_labels, all_preds)

#         # 🛠️ Use raw probabilities for AUC calculation
#         train_auc = roc_auc_score(
#             all_labels,
#             np.array(all_probs),
#             multi_class='ovr'
#         )

#         # ✅ Validation phase
#         model.eval()
#         val_labels = []
#         val_probs = []
        
#         with torch.no_grad():
#             for images, labels in val_loader:
#                 images, labels = images.to(device), labels.to(device)
#                 outputs = model(images)
                
#                 val_probs.extend(torch.softmax(outputs, dim=1).cpu().numpy())
#                 val_labels.extend(labels.cpu().numpy())
        
#         val_auc = roc_auc_score(
#             val_labels,
#             np.array(val_probs),
#             multi_class='ovr'
#         )

#         print(f"Epoch [{epoch+1}/{epochs}] | Train AUC: {train_auc:.4f} | Val AUC: {val_auc:.4f}")
        
#         # Save the best model
#         if val_auc > best_auc:
#             best_auc = val_auc
#             torch.save(model.state_dict(), "best_model.pth")
#             print(f"✅ Best model saved with AUC: {best_auc:.4f}")


In [None]:
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# # Train the model
# for epoch in range(epochs):
#     model.train()
#     for images, labels in train_loader:
#         images, labels = images.to(DEVICE), labels.to(DEVICE)
        
#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, labels)
        
#         loss.backward()
#         optimizer.step()


In [None]:
# X_train = np.load('C:\Project\X_train.npy')
# y_train = np.load('C:\Project\y_train.npy')
# X_val = np.load('C:\Project\X_val.npy')
# y_val = np.load('C:\Project\y_val.npy')

# print(X_train.shape, y_train.shape)
# print(X_val.shape, y_val.shape)

In [None]:
# # 1. Dataset Class
# class CustomDataset(Dataset):
#     def __init__(self, X, y, transform=None):
#         self.X = X
#         self.y = y
#         self.transform = transform
        
#     def __len__(self):
#         return len(self.X)
    
#     def __getitem__(self, idx):
#         image = self.X[idx]
#         label = self.y[idx]
        
#         # Convert to tensor and normalize
#         image = torch.from_numpy(image).float() / 255.0
        
#         # Expand grayscale to RGB if needed
#         if len(image.shape) == 2:  # (H,W)
#             image = image.unsqueeze(0).repeat(3,1,1)  # (3,H,W)
#         elif len(image.shape) == 3 and image.shape[0] == 1:  # (1,H,W)
#             image = image.repeat(3,1,1)
            
#         if self.transform:
#             image = self.transform(image)
            
#         return image, torch.tensor(label, dtype=torch.long)


In [None]:
# # 2. Data Augmentation & Transforms
# train_transform = transforms.Compose([
#     transforms.ToPILImage(),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.RandomRotation(15),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# val_transform = transforms.Compose([
#     transforms.ToPILImage(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

In [None]:
# # 3. Create DataLoaders
# batch_size = 64
# train_dataset = CustomDataset(X_train, y_train, transform=train_transform)
# val_dataset = CustomDataset(X_val, y_val, transform=val_transform)

# train_loader = DataLoader(train_dataset, batch_size=batch_size, 
#                          shuffle=True, num_workers=4, pin_memory=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size,
#                        shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# # 4. Model Definition
# class CustomResNet(nn.Module):
#     def __init__(self, num_classes=3):
#         super().__init__()
#         # Load pretrained ResNet
#         self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        
#         # Freeze all layers first
#         for param in self.resnet.parameters():
#             param.requires_grad = False
            
#         # Replace final layer
#         num_features = self.resnet.fc.in_features
#         self.resnet.fc = nn.Sequential(
#             nn.Linear(num_features, 256),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(256, num_classes))
        
#     def forward(self, x):
#         return self.resnet(x)


In [None]:
# # 5. Training Setup
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = CustomResNet(num_classes=3).to(device)
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

In [None]:
# # 6. Training Loop
# def train_model(model, criterion, optimizer, scheduler, num_epochs=20):
#     best_acc = 0.0
#     train_losses, val_losses = [], []
#     train_accs, val_accs = [], []
    
#     for epoch in range(num_epochs):
#         print(f'Epoch {epoch+1}/{num_epochs}')
#         print('-' * 10)
        
#         # Training phase
#         model.train()
#         running_loss = 0.0
#         running_corrects = 0
        
#         for inputs, labels in tqdm(train_loader):
#             inputs = inputs.to(device)
#             labels = labels.to(device)
            
#             optimizer.zero_grad()
            
#             outputs = model(inputs)
#             _, preds = torch.max(outputs, 1)
#             loss = criterion(outputs, labels)
            
#             loss.backward()
#             optimizer.step()
            
#             running_loss += loss.item() * inputs.size(0)
#             running_corrects += torch.sum(preds == labels.data)
            
#         epoch_loss = running_loss / len(train_dataset)
#         epoch_acc = running_corrects.double() / len(train_dataset)
#         train_losses.append(epoch_loss)
#         train_accs.append(epoch_acc)
        
#         # Validation phase
#         model.eval()
#         val_loss = 0.0
#         val_corrects = 0
        
#         with torch.no_grad():
#             for inputs, labels in val_loader:
#                 inputs = inputs.to(device)
#                 labels = labels.to(device)
                
#                 outputs = model(inputs)
#                 _, preds = torch.max(outputs, 1)
#                 loss = criterion(outputs, labels)
                
#                 val_loss += loss.item() * inputs.size(0)
#                 val_corrects += torch.sum(preds == labels.data)
                
#         val_loss = val_loss / len(val_dataset)
#         val_acc = val_corrects.double() / len(val_dataset)
#         val_losses.append(val_loss)
#         val_accs.append(val_acc)
        
#         scheduler.step(val_loss)
        
#         print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
#         print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
        
#         # Save best model
#         if val_acc > best_acc:
#             best_acc = val_acc
#             torch.save(model.state_dict(), 'best_model.pth')
            
#     return model, train_losses, val_losses, train_accs, val_accs

In [None]:
# # 7. Train the model
# model, train_losses, val_losses, train_accs, val_accs = train_model(
#     model, criterion, optimizer, scheduler, num_epochs=20)