In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm  # For BiT-M model

from tqdm import tqdm  # Import tqdm for progress bar

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Hyperparameters
batch_size = 101  # Reduced batch size for efficient training
num_classes = 101  # Food-101 has 101 classes
learning_rate = 1e-4  # Lower LR for fine-tuning
num_epochs = 5  # More epochs for better convergence

In [4]:
# Data Augmentation and Normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load Food-101 dataset
train_dataset = datasets.Food101(root="data", split="train", transform=transform, download=True)
test_dataset = datasets.Food101(root="data", split="test", transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Load BiT-M ResNet50x1 Model
teacher_model = timm.create_model("resnetv2_50x1_bitm", pretrained=True, num_classes=num_classes)
teacher_model = teacher_model.to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(teacher_model.parameters(), lr=learning_rate, weight_decay=1e-4)

  model = create_fn(


In [6]:
# Training Function with Batch-Wise Progress Tracking
def train_model():
    teacher_model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        # tqdm progress bar for batch processing
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (images, labels) in progress_bar:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = teacher_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update tqdm description with latest loss
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%\n")

In [10]:
# Evaluation Function with Progress Bar
def evaluate_model():
    teacher_model.eval()
    correct = 0
    total = 0

    # tqdm progress bar for evaluation
    progress_bar = tqdm(test_loader, total=len(test_loader), desc="Evaluating", leave=False)

    with torch.no_grad():
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = teacher_model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    test_accuracy = 100 * correct / total
    print(f"\nTest Accuracy: {test_accuracy:.2f}%")

In [8]:
# Run Training and Evaluation
train_model()

Epoch 1/5:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 1/5: 100%|██████████| 750/750 [5:12:41<00:00, 25.02s/it, loss=0.8707]     


Epoch [1/5], Loss: 1.3716, Accuracy: 64.71%



Epoch 2/5: 100%|██████████| 750/750 [4:03:48<00:00, 19.51s/it, loss=0.5259]  


Epoch [2/5], Loss: 0.7976, Accuracy: 78.05%



Epoch 3/5: 100%|██████████| 750/750 [4:04:26<00:00, 19.56s/it, loss=0.6051]  


Epoch [3/5], Loss: 0.6090, Accuracy: 82.89%



Epoch 4/5: 100%|██████████| 750/750 [4:05:08<00:00, 19.61s/it, loss=0.5065]  


Epoch [4/5], Loss: 0.4922, Accuracy: 85.82%



Epoch 5/5: 100%|██████████| 750/750 [4:04:19<00:00, 19.55s/it, loss=0.2995]  


Epoch [5/5], Loss: 0.3974, Accuracy: 88.22%



In [11]:
evaluate_model()

                                                             


Test Accuracy: 81.72%




In [12]:
torch.save(teacher_model.state_dict(), "teacher models/bitm_resnet50x1_F101.pth")
print("Teacher Model Saved!")

Teacher Model Saved!
