### Necessary imports

In [None]:
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from torchvision.models import efficientnet_v2_m, efficientnet_v2_s

### Model choices

In [None]:
# Option 1: EfficientNetV2-Small
class FoodClassifierV2S(nn.Module):
    """Torchvision EfficientNetV2-S backbone with a new ``num_classes``-way head.

    Args:
        num_classes: Number of output logits (default 101 — presumably the
            Food-101 classes; confirm against the dataset used).
        pretrained: If True (default), initialise the backbone from
            torchvision's ``IMAGENET1K_V1`` weights; if False, use random
            initialisation.
    """

    def __init__(self, num_classes=101, pretrained=True):
        super().__init__()
        # ``weights=`` replaces the ``pretrained=`` flag deprecated in
        # torchvision 0.13; "IMAGENET1K_V1" is exactly what ``pretrained=True``
        # used to select for this model.
        weights = "IMAGENET1K_V1" if pretrained else None
        self.backbone = efficientnet_v2_s(weights=weights)
        # torchvision builds the head as Sequential(Dropout, Linear); only the
        # Linear at index 1 is swapped, so the existing dropout is kept.
        self.backbone.classifier[1] = nn.Linear(
            self.backbone.classifier[1].in_features,
            num_classes,
        )

    def forward(self, x):
        """Return class logits produced by the backbone for the batch ``x``."""
        return self.backbone(x)

# Option 2: EfficientNetV2-Medium
class FoodClassifierV2M(nn.Module):
    """Torchvision EfficientNetV2-M backbone with a new ``num_classes``-way head.

    Note: torchvision's V2-M ImageNet weights were trained at 480x480; feeding
    smaller images works but is below the backbone's native resolution.

    Args:
        num_classes: Number of output logits (default 101 — presumably the
            Food-101 classes; confirm against the dataset used).
        pretrained: If True (default), initialise the backbone from
            torchvision's ``IMAGENET1K_V1`` weights; if False, use random
            initialisation.
    """

    def __init__(self, num_classes=101, pretrained=True):
        super().__init__()
        # ``weights=`` replaces the ``pretrained=`` flag deprecated in
        # torchvision 0.13; "IMAGENET1K_V1" is exactly what ``pretrained=True``
        # used to select for this model.
        weights = "IMAGENET1K_V1" if pretrained else None
        self.backbone = efficientnet_v2_m(weights=weights)
        # torchvision builds the head as Sequential(Dropout, Linear); only the
        # Linear at index 1 is swapped, so the existing dropout is kept.
        self.backbone.classifier[1] = nn.Linear(
            self.backbone.classifier[1].in_features,
            num_classes,
        )

    def forward(self, x):
        """Return class logits produced by the backbone for the batch ``x``."""
        return self.backbone(x)

# Option 3: Original EfficientNet-B4 (for comparison)
class FoodClassifierB4(nn.Module):
    """Pretrained EfficientNet-B4 (``efficientnet_pytorch``) with a fresh head.

    Kept as a baseline against the EfficientNetV2 variants. The pretrained
    ``_fc`` layer is discarded and replaced by a newly initialised
    ``nn.Linear`` producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=101):
        super().__init__()
        net = EfficientNet.from_pretrained('efficientnet-b4')
        feature_dim = net._fc.in_features
        # Replace the pretrained final linear layer with one sized for this task.
        net._fc = nn.Linear(feature_dim, num_classes)
        self.backbone = net

    def forward(self, x):
        """Run ``x`` through the adapted B4 network and return its logits."""
        return self.backbone(x)

### Training-time augmentation transforms

In [None]:
# Standard ImageNet per-channel mean/std, shared by every pipeline below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def build_train_transform(image_size):
    """Return the training augmentation pipeline for square ``image_size`` inputs.

    Steps: resize to ``(image_size, image_size)``, random horizontal flip,
    random rotation within +/-15 degrees, colour jitter (brightness, contrast
    and saturation 0.2 each), conversion to a tensor, and ImageNet
    normalisation.

    Args:
        image_size: Side length in pixels of the square output image.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# For EfficientNetV2. 384 is the native resolution of torchvision's V2-S
# weights; V2-M's weights were trained at 480, so 384 is below its native size.
v2_transform = build_train_transform(384)

# For EfficientNet-B4 (native resolution 380).
b4_transform = build_train_transform(380)