In [None]:
!pip install wandb -q

In [1]:
import os
import pandas as pd
import shutil
from tqdm import tqdm

import wandb
from kaggle_secrets import UserSecretsClient

In [2]:
# Log in to W&B using the API key stored in Kaggle Secrets.
# Any failure here (e.g. the secret is missing or the login call fails) disables
# tracking instead of stopping the notebook, but the reason is now printed rather
# than silently swallowed.
try:
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
    wandb.login(key=wandb_api_key)
    WANDB_ACTIVE = True
    print("Successfully logged into W&B.")
except Exception as exc:  # a bare `except:` would also catch KeyboardInterrupt / SystemExit
    WANDB_ACTIVE = False
    print("WARNING: Could not log into W&B. Training will continue without tracking.")
    print(f"Reason: {type(exc).__name__}: {exc}")


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjnikiema[0m ([33mimg_seg[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Successfully logged into W&B.


In [None]:
# List every file available under the Kaggle input mount (walk order).
input_files = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_files:
    print(path)

In [3]:
def prepare_validation_set_from_kaggle(input_dir, output_dir):
    """
    Sort the ImageNet validation images into one sub-folder per class.

    The validation images live in a single flat folder
    (``<input_dir>/ILSVRC/Data/CLS-LOC/val``), while ``ImageFolder`` needs
    ``<root>/<class_id>/<image>.JPEG``. Labels come from ``LOC_val_solution.csv``:
    the class ID is the first whitespace-separated token of ``PredictionString``.

    Destination files that already exist with the same size as their source are
    skipped, so re-running the cell (e.g. after an interruption) only copies what is
    missing instead of re-copying every image.

    Args:
        input_dir (str): The root path to the Kaggle input data.
        output_dir (str): The path to the writeable output directory (/kaggle/working/).

    Returns:
        str: Path of the class-sorted validation directory (``<output_dir>/val_sorted``).
    """
    print("--- Preparing ONLY the validation set ---")

    # Define paths
    val_images_path = os.path.join(input_dir, 'ILSVRC/Data/CLS-LOC/val')
    solution_file_path = os.path.join(input_dir, 'LOC_val_solution.csv')

    # The new sorted validation directory will be created in our workspace
    sorted_val_dir = os.path.join(output_dir, 'val_sorted')
    os.makedirs(sorted_val_dir, exist_ok=True)

    print(f"Reading solution file from: {solution_file_path}")
    df = pd.read_csv(solution_file_path)

    # PredictionString looks like 'n01440764 x1 y1 x2 y2 ...'; keep only the class ID.
    # split() (no argument) also tolerates repeated or leading/trailing spaces.
    df['class_id'] = df['PredictionString'].str.split().str[0]

    print(f"Found {len(df)} images to sort.")
    print(f"Copying and sorting images from {val_images_path} to {sorted_val_dir}...")

    # Create each class folder once, instead of once per image.
    for class_id in df['class_id'].unique():
        os.makedirs(os.path.join(sorted_val_dir, class_id), exist_ok=True)

    copied = 0
    # zip over the two columns is much faster than DataFrame.iterrows().
    for image_id, class_id in tqdm(zip(df['ImageId'], df['class_id']), total=len(df)):
        file_name = image_id + '.JPEG'
        src_path = os.path.join(val_images_path, file_name)
        dest_path = os.path.join(sorted_val_dir, class_id, file_name)

        # Skip complete copies from a previous run; a size mismatch means the copy
        # was interrupted, so it is redone.
        if os.path.exists(dest_path) and os.path.getsize(dest_path) == os.path.getsize(src_path):
            continue
        shutil.copyfile(src_path, dest_path)
        copied += 1

    print("\n--- Validation set preparation complete! ---")
    print(f"Copied {copied} new file(s); the rest were already in place.")
    print(f"Sorted validation data is ready in: {sorted_val_dir}")
    return sorted_val_dir

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Training on device: {device}")

Training on device: cuda


# I. Building Blocks of DenseNet

We'll start by implementing the core components of the DenseNet architecture.

## 1. DenseNet Simple Layer

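The simple layer in a DenseNet consists of a Batch Normalization layer, a ReLU activation function, and a 3x3 Convolutional layer that produces `growth_rate` (k) new feature maps. These new maps are concatenated with the layer's input along the channel dimension.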

In [6]:
class DenseNetSimpleLayer(nn.Module):
    """Basic DenseNet layer: BN -> ReLU -> 3x3 conv, output concatenated to the input.

    Args:
        in_channels (int): Number of input channels.
        growth_rate (int): Number of new feature maps produced (k in the paper).
        dropout_rate (float): Probability of the ``nn.Dropout2d`` applied to the new
            feature maps before concatenation; 0 (the default) disables it. This
            argument gives the layer the same constructor signature as
            ``DenseNetBottleneckLayer``, because ``DenseNet._make_dense_block`` builds
            layers as ``block(in_planes, growth_rate, dropout_rate)``; without it,
            passing this class as ``block`` raised a TypeError.
    """

    def __init__(self, in_channels, growth_rate, dropout_rate=0):
        super(DenseNetSimpleLayer, self).__init__()
        self.dropout_rate = dropout_rate
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
        # Only registered when used, so the module/state_dict layout of the
        # dropout-free layer is unchanged.
        if self.dropout_rate > 0:
            self.dropout = nn.Dropout2d(p=self.dropout_rate)

    def forward(self, x):
        """
        Forward pass of the DenseNet Simple Layer.

        Args:
            x (torch.Tensor): Input tensor with ``in_channels`` channels.

        Returns:
            torch.Tensor: ``x`` with ``growth_rate`` new channels appended (dim 1).
        """
        out = self.conv1(self.relu1(self.bn1(x)))
        if self.dropout_rate > 0:
            out = self.dropout(out)
        return torch.cat([x, out], 1)

## 2. DenseNet Bottleneck Layer

The bottleneck layer is a more computationally efficient version of the simple layer. Each of its two convolutions is preceded by Batch Normalization and ReLU. A 1x1 convolution first reduces the input to 4 * growth_rate feature maps, and then the more expensive 3x3 convolution produces the growth_rate new feature maps.

In [7]:
class DenseNetBottleneckLayer(nn.Module):
    """DenseNet-B layer: BN-ReLU-1x1 conv (4k maps) then BN-ReLU-3x3 conv (k maps).

    The k new feature maps (optionally passed through ``nn.Dropout2d``) are
    concatenated after the input along the channel dimension.

    Args:
        in_channels (int): Number of input channels.
        growth_rate (int): Number of output channels of the 3x3 convolution (k).
        dropout_rate (float): Dropout2d probability on the new feature maps;
            0 disables dropout (no dropout module is created).
    """

    def __init__(self, in_channels, growth_rate, dropout_rate=0):
        super().__init__()
        bottleneck_width = 4 * growth_rate
        self.dropout_rate = dropout_rate

        # 1x1 bottleneck: squeeze the input down to 4 * growth_rate maps.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, bottleneck_width, kernel_size=1, stride=1, bias=False)

        # 3x3 convolution producing the growth_rate new feature maps.
        self.bn2 = nn.BatchNorm2d(bottleneck_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bottleneck_width, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

        if dropout_rate > 0:
            self.dropout = nn.Dropout2d(p=dropout_rate)

    def forward(self, x):
        """Return ``x`` with ``growth_rate`` freshly computed channels appended (dim 1)."""
        squeezed = self.conv1(self.relu1(self.bn1(x)))
        new_features = self.conv2(self.relu2(self.bn2(squeezed)))
        # Dropout acts on the new maps only, before they join the shared feature stack.
        if self.dropout_rate > 0:
            new_features = self.dropout(new_features)
        return torch.cat([x, new_features], 1)


## 3. Transition Layer

The transition layer connects two dense blocks. It consists of a Batch Normalization layer, a ReLU activation, a 1x1 Convolutional layer that reduces the number of channels (compression), and a 2x2 Average Pooling layer (stride 2) that halves the spatial dimensions.

In [8]:
class TransitionLayer(nn.Module):
    """Link between two dense blocks: BN -> ReLU -> 1x1 conv -> 2x2/stride-2 average pool.

    Args:
        in_channels (int): Channels produced by the preceding dense block.
        out_channels (int): Channels after the 1x1 (compression) convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Compress ``x`` to ``out_channels`` channels and downsample it spatially by 2."""
        activated = self.relu1(self.bn1(x))
        compressed = self.conv1(activated)
        return self.avg_pool(compressed)

# II. Assembling the Full DenseNet Model

Now we will combine these building blocks to create the complete DenseNet architecture.

In [9]:
class DenseNet(nn.Module):
    """DenseNet built from dense blocks separated by transition layers.

    Two layouts are supported, selected by ``dataset_used``:

    * ``"cifar"``: one 3x3 stem convolution, then three dense blocks separated by
      two transition layers (``nblocks`` needs at least 3 entries).
    * any other value (ImageNet layout): a 7x7/stride-2 conv + BN + ReLU +
      3x3/stride-2 max-pool stem, then four dense blocks separated by three
      transition layers (``nblocks`` needs at least 4 entries).

    Args:
        block (type[nn.Module]): Dense-layer class (Simple or Bottleneck), built as
            ``block(in_channels, growth_rate, dropout_rate)``.
        nblocks (list of int): The number of layers in each dense block.
        growth_rate (int): The growth rate (k).
        reduction (float): The compression factor for the transition layers.
        num_classes (int): The number of output classes.
        dropout_rate (float): Dropout probability forwarded to every dense layer;
            0 disables dropout.
        init_weights (bool): Whether to initialize the weights (Kaiming-normal convs,
            BatchNorm weight 1 / bias 0, Linear bias 0).
        dataset_used (str): ``"cifar"`` for the CIFAR layout, anything else for ImageNet.

    Raises:
        ValueError: If ``nblocks`` has fewer entries than the selected layout needs.
    """

    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10,
                 dropout_rate=0, init_weights=True, dataset_used="cifar"):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate
        self.dropout_rate = dropout_rate
        self.dataset_used = dataset_used
        is_cifar = self.dataset_used == "cifar"

        # Validate up front so a bad configuration fails with a clear message instead
        # of an IndexError after part of the network has already been built.
        required_blocks = 3 if is_cifar else 4
        if len(nblocks) < required_blocks:
            raise ValueError(
                f"DenseNet with dataset_used={dataset_used!r} needs {required_blocks} "
                f"dense-block sizes in nblocks, got {len(nblocks)}: {list(nblocks)}"
            )

        # The stem always outputs 2 * growth_rate feature maps.
        num_planes = 2 * growth_rate
        if is_cifar:
            # Initial convolution for CIFAR-X (stride 1, spatial size preserved).
            self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
        else:
            # Initial convolution for ImageNet: stride-2 conv followed by stride-2 max-pool.
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, num_planes, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(num_planes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            )

        # Each dense block adds nblocks[i] * growth_rate channels; each transition
        # layer compresses the channel count to int(channels * reduction).

        # First Dense Block
        self.dense1 = self._make_dense_block(block, num_planes, nblocks[0])
        num_planes += nblocks[0] * growth_rate
        out_planes = int(num_planes * reduction)
        self.trans1 = TransitionLayer(num_planes, out_planes)
        num_planes = out_planes

        # Second Dense Block
        self.dense2 = self._make_dense_block(block, num_planes, nblocks[1])
        num_planes += nblocks[1] * growth_rate
        out_planes = int(num_planes * reduction)
        self.trans2 = TransitionLayer(num_planes, out_planes)
        num_planes = out_planes

        # Third Dense Block
        self.dense3 = self._make_dense_block(block, num_planes, nblocks[2])
        num_planes += nblocks[2] * growth_rate

        if not is_cifar:
            out_planes = int(num_planes * reduction)
            self.trans3 = TransitionLayer(num_planes, out_planes)
            num_planes = out_planes

            # Fourth Dense Block (ImageNet layout only)
            self.dense4 = self._make_dense_block(block, num_planes, nblocks[3])
            num_planes += nblocks[3] * growth_rate

        # Final layers: BN (+ ReLU in forward), global average pool, linear classifier.
        self.bn = nn.BatchNorm2d(num_planes)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(num_planes, num_classes)

        if init_weights:
            self._initialize_weights()

    def _make_dense_block(self, block, in_planes, nblock):
        """Stack ``nblock`` dense layers; layer i sees ``in_planes + i * growth_rate`` channels."""
        layers = []
        for _ in range(nblock):
            layers.append(block(in_planes, self.growth_rate, self.dropout_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) convs, BN weight 1 / bias 0, Linear bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a batch of 3-channel images to class logits of shape ``(N, num_classes)``."""
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        if self.dataset_used != "cifar":
            out = self.trans3(self.dense3(out))
            out = self.dense4(out)
        else:
            out = self.dense3(out)
        out = self.avg_pool(F.relu(self.bn(out)))
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out

# III. Training and Evaluation

In [10]:
def Densenet_cifar(k=12, dropout_rate=0, num_classes=10):
    """Build the CIFAR DenseNet-BC: 3 dense blocks x 16 bottleneck layers (depth 100).

    Args:
        k (int): Growth rate.
        dropout_rate (float): Dropout inside each bottleneck layer (0 = off).
        num_classes (int): Number of output classes (10 for CIFAR-10).

    Returns:
        DenseNet: A freshly initialised model using the CIFAR layout.
    """
    layers_per_block = [16] * 3
    return DenseNet(
        DenseNetBottleneckLayer,
        layers_per_block,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
        growth_rate=k,
    )

In [11]:
def _imagenet_densenet(nblocks, growth_rate, num_classes, dropout_rate):
    """Shared builder for the ImageNet-layout DenseNet-BC variants below."""
    return DenseNet(
        DenseNetBottleneckLayer,
        nblocks,
        growth_rate=growth_rate,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
        dataset_used="imagenet",
    )


# The factories previously accepted no arguments, so `DenseNet121(num_classes=...)`
# raised a TypeError. They also inherited DenseNet's CIFAR default of 10 classes.
# They now default to the 1000 ImageNet classes and accept overrides.

def DenseNet121(num_classes=1000, dropout_rate=0):
    """DenseNet-121: dense blocks [6, 12, 24, 16], growth rate 32, ImageNet layout."""
    return _imagenet_densenet([6, 12, 24, 16], 32, num_classes, dropout_rate)


def DenseNet169(num_classes=1000, dropout_rate=0):
    """DenseNet-169: dense blocks [6, 12, 32, 32], growth rate 32, ImageNet layout."""
    return _imagenet_densenet([6, 12, 32, 32], 32, num_classes, dropout_rate)


def DenseNet201(num_classes=1000, dropout_rate=0):
    """DenseNet-201: dense blocks [6, 12, 48, 32], growth rate 32, ImageNet layout."""
    return _imagenet_densenet([6, 12, 48, 32], 32, num_classes, dropout_rate)


def DenseNet161(num_classes=1000, dropout_rate=0):
    """DenseNet-161: dense blocks [6, 12, 36, 24], growth rate 48, ImageNet layout."""
    return _imagenet_densenet([6, 12, 36, 24], 48, num_classes, dropout_rate)

## 1. For CIFAR-10 and CIFAR-100 with standard data augmentation (random crop + horizontal flip)

In [12]:
# Per-channel statistics used to normalise CIFAR images (same for train and test).
# NOTE(review): these std values are the ones commonly copied between CIFAR training
# repos; the per-pixel std of the CIFAR-10 training set is usually quoted as roughly
# (0.247, 0.243, 0.261). Kept unchanged so results stay comparable; confirm before changing.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Training: 4-pixel padded random 32x32 crop + random horizontal flip.
transform_train_cifar = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Test: no augmentation, only tensor conversion and normalisation.
transform_test_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

In [13]:
# TRAINING FUNCTION

def train(epoch, model, trainloader, optimizer, criterion, device):
    """Run one training epoch over ``trainloader``.

    Prints the epoch number and the first param group's learning rate, and wraps the
    loader in a tqdm progress bar. The bar's running loss/accuracy caption is only
    updated on epochs that are multiples of 10, to keep notebook output short.

    Args:
        epoch (int): Current epoch index (used for printing / caption throttling).
        model (nn.Module): Network to optimise (switched to train mode).
        trainloader (Iterable): Yields ``(inputs, targets)`` batches; must support ``len``.
        optimizer (torch.optim.Optimizer): Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        tuple[float, float]: (mean of the per-batch losses, accuracy in percent).
    """
    model.train()
    print(f'\nEpoch: {epoch} | LR: {optimizer.param_groups[0]["lr"]:.5f}')
    show_stats = epoch % 10 == 0
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress_bar = tqdm(enumerate(trainloader), total=len(trainloader))
    for step, (inputs, targets) in progress_bar:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predicted = logits.max(1).indices
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        if show_stats:
            progress_bar.set_description(f'Loss: {loss_sum/(step+1):.3f} | Acc: {100.*n_correct/n_seen:.3f}%')

    return loss_sum / len(trainloader), 100. * n_correct / n_seen


# EVALUATION FUNCTION

def evaluate_cifar(epoch, model, testloader, criterion, device):
    """Evaluate ``model`` on ``testloader`` without gradient tracking.

    A short summary is printed on epochs that are multiples of 10.

    Args:
        epoch (int): Current epoch; only used to decide whether to print.
        model (nn.Module): Network to evaluate (switched to eval mode).
        testloader (Iterable): Yields ``(inputs, targets)`` batches.
        criterion: Loss function; assumed to return the batch *mean*
            (the default ``reduction='mean'`` of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.

    Returns:
        tuple[float, float]: (average per-sample loss, accuracy in percent).

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)
            # Weight each batch mean by its size so a smaller last batch does not skew the average.
            test_loss += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError("testloader yielded no samples")

    avg_loss = test_loss / total
    avg_acc = 100. * correct / total
    error_rate = 100. - avg_acc

    if epoch % 10 == 0:
        print(f"--- Epoch {epoch} Test Results ---")
        # The loss is not a percentage, so it is printed without a '%' suffix.
        print(f"Accuracy: {avg_acc:.2f}% | Error Rate: {error_rate:.2f}% | Loss: {avg_loss:.4f}")
        print("--------------------------")

    return avg_loss, avg_acc

### CIFAR-10

In [14]:
EPOCHS = 300
BATCH_SIZE = 128
BASE_LR = 0.1
CHECKPOINT_EVERY = 25  # epochs between checkpoint uploads to W&B Artifacts

# DataLoaders
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train_cifar)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test_cifar)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

for k in [12, 24]:
    print("#"*50)
    print(f"--- Starting training for k = {k} ---")

    # Manage the run ID for robust resuming
    run_id = None
    run_id_path = f"/kaggle/working/wandb_run_id_k{k}.txt"
    if os.path.exists(run_id_path):
        with open(run_id_path, 'r') as f:
            # strip(): a trailing newline/space would otherwise become part of the ID.
            run_id = f.read().strip() or None
        print(f"Found existing run ID for k={k}: {run_id}. Attempting to resume.")

    # Initialize a new W&B run (or resume the one whose ID was saved above)
    if WANDB_ACTIVE:
        run = wandb.init(
            project="densenet-cifar10",
            name=f"densenet_k{k}",
            config={
                "growth_rate": k, "epochs": EPOCHS, "batch_size": BATCH_SIZE,
                "learning_rate": BASE_LR, "optimizer": "SGD_Nesterov"
            },
            id=run_id,
            resume="allow",
            job_type="train"
        )

        # Persist the ID whenever it differs from what was on disk: a fresh run, or an
        # empty/stale ID file. The next session can then resume this run.
        if run.id != run_id:
            with open(run_id_path, 'w') as f:
                f.write(run.id)
            print(f"Created new run with ID: {run.id}. Saved ID for future resume.")

    # Create the model, optimizer, and scheduler
    model = Densenet_cifar(k=k).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=BASE_LR, momentum=0.9, weight_decay=1e-4, nesterov=True)
    # Divide the LR by 10 at 50% and 75% of training.
    milestones = [int(EPOCHS * 0.5), int(EPOCHS * 0.75)]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    start_epoch = 0

    # Attempt to resume from a W&B checkpoint if the run was successfully resumed
    if WANDB_ACTIVE and run.resumed:
        try:
            print("Run successfully resumed. Attempting to load checkpoint from W&B Artifacts...")
            # Fetch the latest version of the artifact
            artifact = run.use_artifact(f'densenet-cifar-k{k}:latest')
            # Download each k / artifact version into its own folder. A shared folder
            # accumulates checkpoints of the other k and of older versions, and
            # os.listdir() order is arbitrary, so the wrong file could be loaded.
            artifact_dir = artifact.download(root=f"/kaggle/working/artifacts/k{k}/{artifact.version}")
            checkpoint_names = sorted(name for name in os.listdir(artifact_dir) if name.endswith('.pth'))
            if not checkpoint_names:
                raise FileNotFoundError(f"No .pth checkpoint found in {artifact_dir}")
            checkpoint_file = os.path.join(artifact_dir, checkpoint_names[-1])

            # Load the state of the model, optimizer, etc.
            # map_location lets a checkpoint written on GPU be restored on CPU (and vice versa).
            checkpoint = torch.load(checkpoint_file, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Checkpoint loaded. Resuming training at epoch {start_epoch}.")
        except Exception as e:
            print(f"Could not load checkpoint. Starting from scratch. Error: {e}")

    # Main Training Loop
    for epoch in range(start_epoch, EPOCHS):
        # Learning rate actually used for this epoch. scheduler.get_last_lr() after
        # scheduler.step() would report the *next* epoch's rate.
        epoch_lr = optimizer.param_groups[0]["lr"]

        # Training and evaluation
        train_loss, train_acc = train(epoch, model, trainloader, optimizer, criterion, device)
        test_loss, test_acc = evaluate_cifar(epoch, model, testloader, criterion, device)
        scheduler.step()

        # Log metrics to W&B
        if WANDB_ACTIVE:
            wandb.log({
                "epoch": epoch, "train_loss": train_loss, "train_accuracy": train_acc,
                "test_loss": test_loss, "test_accuracy": test_acc,
                "learning_rate": epoch_lr
            })

        # Save a checkpoint to W&B Artifacts periodically and after the final epoch
        if WANDB_ACTIVE and ((epoch + 1) % CHECKPOINT_EVERY == 0 or epoch == EPOCHS - 1):
            # First, save the file locally in the Kaggle environment
            local_path = f"/kaggle/working/checkpoint_k{k}_epoch{epoch}.pth"
            torch.save({
                'epoch': epoch, 'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
            }, local_path)

            # Create a W&B Artifact to version the model
            artifact = wandb.Artifact(
                name=f'densenet-cifar-k{k}', # Name for the artifact collection
                type='model',
                description=f'Checkpoint for Densenet k={k} at epoch {epoch}'
            )
            artifact.add_file(local_path)

            # Upload the artifact to W&B servers
            run.log_artifact(artifact)
            print(f"Checkpoint for epoch {epoch} saved to W&B Artifacts.")

    # Finish the W&B run
    if WANDB_ACTIVE:
        run.finish()

    print(f"--- Training finished for k = {k} ---")

##################################################
--- Starting training for k = 12 ---



Epoch: 0 | LR: 0.10000


Loss: 1.456 | Acc: 46.042%: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s]


--- Epoch 0 Test Results ---
Accuracy: 52.48% | Error Rate: 47.52% | Loss: 1.51%
--------------------------

Epoch: 1 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.85it/s]



Epoch: 2 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.85it/s]



Epoch: 3 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.85it/s]



Epoch: 4 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.86it/s]



Epoch: 5 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.85it/s]



Epoch: 6 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.85it/s]



Epoch: 7 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.85it/s]



Epoch: 8 | LR: 0.10000


100%|██████████| 391/391 [01:06<00:00,  5.84it/s]



Epoch: 9 | LR: 0.10000


100%|██████████| 391/391 [01:07<00:00,  5.83it/s]



Epoch: 10 | LR: 0.10000


Loss: 0.339 | Acc: 88.275%:  33%|███▎      | 129/391 [00:22<00:45,  5.74it/s]


KeyboardInterrupt: 

## 2. For ImageNet

In [None]:
# DATA LOADING AND TRANSFORMATION

# ImageNet per-channel statistics, shared by the training and validation pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Training: random-resized 224x224 crop + random horizontal flip.
train_ops = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
transform_train = transforms.Compose(train_ops)

# Validation: deterministic resize to 256 followed by a 224x224 center crop.
val_ops = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
transform_val = transforms.Compose(val_ops)

In [None]:
# TRAINING FUNCTION
def train_imagenet(epoch, model, trainloader, optimizer, criterion, device):
    """Run one training epoch, printing the running loss every 100 batches.

    Args:
        epoch (int): Current epoch index (printed only).
        model (nn.Module): Network to optimise (switched to train mode).
        trainloader (Iterable): Yields ``(inputs, targets)`` batches; must support ``len``.
        optimizer (torch.optim.Optimizer): Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        tuple[float, float]: (mean of the per-batch losses, accuracy in percent).
        The return value matches ``train``; both values are 0.0 for an empty loader.
        Previously the function returned None, and the accumulated ``total`` was never used.
    """
    print(f'\nEpoch: {epoch}')
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        correct += outputs.max(1).indices.eq(targets).sum().item()
        total += targets.size(0)
        num_batches = batch_idx + 1

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} | Batch {batch_idx}/{len(trainloader)} | Loss: {train_loss/(batch_idx+1):.3f}')

    avg_loss = train_loss / num_batches if num_batches else 0.0
    avg_acc = 100. * correct / total if total else 0.0
    return avg_loss, avg_acc

# EVALUATION FUNCTION (WITH TOP-1 AND TOP-5)
def evaluate_imagenet(model, valloader, criterion, device):
    """Evaluate ``model`` on ``valloader`` and print loss, top-1 and top-5 accuracy.

    Args:
        model (nn.Module): Network to evaluate (switched to eval mode).
        valloader (Iterable): Yields ``(inputs, targets)`` batches.
        criterion: Loss function; assumed to return the batch *mean*
            (default ``reduction='mean'``).
        device: Device the batches are moved to.

    Returns:
        tuple[float, float, float]: (average per-sample loss, top-1 %, top-5 %).
        The function previously returned None, so callers could not track the best accuracy.

    Raises:
        ValueError: If ``valloader`` yields no samples.
    """
    model.eval()
    val_loss = 0.0
    correct_top1 = 0
    correct_top5 = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in valloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)
            # Weight by batch size so a smaller final batch does not skew the average.
            val_loss += criterion(outputs, targets).item() * batch_size

            # Top-k predicted class indices per sample, shape (batch, k). k is capped at
            # the number of classes so topk() does not fail for models with < 5 outputs.
            k = min(5, outputs.size(1))
            topk_pred = outputs.topk(k, dim=1, largest=True, sorted=True).indices
            hits = topk_pred.eq(targets.view(-1, 1))  # (batch, k) boolean matches

            correct_top1 += hits[:, :1].sum().item()
            correct_top5 += hits.any(dim=1).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("valloader yielded no samples")

    # Calculate final accuracies
    avg_loss = val_loss / total
    top1_acc = 100. * correct_top1 / total
    top5_acc = 100. * correct_top5 / total

    print("\n--- Validation Results ---")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Top-1 Accuracy: {top1_acc:.2f}% ({int(correct_top1)}/{total})")
    print(f"Top-5 Accuracy: {top5_acc:.2f}% ({int(correct_top5)}/{total})")
    print("--------------------------\n")

    return avg_loss, top1_acc, top5_acc

In [None]:
INPUT_DIR = "/kaggle/input/imagenet-object-localization-challenge"
OUTPUT_DIR = "/kaggle/working/"

# Training images are read straight from the read-only input mount (no copy).
train_dir = os.path.join(INPUT_DIR, 'ILSVRC/Data/CLS-LOC/train')

# Validation images must be re-arranged into per-class folders in the writable
# workspace; the helper does that and returns the new root.
sorted_val_path = prepare_validation_set_from_kaggle(input_dir=INPUT_DIR, output_dir=OUTPUT_DIR)
val_dir = sorted_val_path

print(f"\nUsing Training data from: {train_dir}")
print(f"Using Validation data from: {val_dir}")

In [None]:
# DataLoaders
train_dataset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

val_dataset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_val)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4)

# ImageFolder assigns label indices from the sorted class-folder names. The validation
# folders are built from the solution CSV, so any missing or extra class would shift
# the indices and silently make every validation accuracy wrong.
assert val_dataset.classes == train_dataset.classes, "Train/validation class folders do not match"

# Use our from-scratch model, adjusting for the number of classes in the dataset.
# This is the DenseNet-121 configuration (dense blocks [6, 12, 24, 16], growth rate 32),
# built directly so that the class count can be passed in.
num_classes = len(train_dataset.classes)
model = DenseNet(DenseNetBottleneckLayer, [6, 12, 24, 16], growth_rate=32,
                 num_classes=num_classes, dataset_used="imagenet").to(device)
print(f"Custom DenseNet-121 model for ImageNet created successfully with {num_classes} classes.")

criterion = nn.CrossEntropyLoss()
# Nesterov momentum 0.9 + weight decay 1e-4, the same optimizer setup as the CIFAR runs.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
# Divide the LR by 10 every 30 epochs (i.e. at epochs 30 and 60 of 90).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

IMAGENET_EPOCHS = 90

# Training Loop
for epoch in range(IMAGENET_EPOCHS):
    train(epoch, model, trainloader, optimizer, criterion, device)
    evaluate_imagenet(model, valloader, criterion, device)
    scheduler.step()
    # TODO: keep the best checkpoint, e.g.
    # torch.save(model.state_dict(), 'densenet_imagenet_scratch_best.pth'), once the
    # validation top-1 accuracy is available here. The earlier commented-out version
    # referenced an undefined `acc`.