In [1]:
import os
import torch
import numpy as np

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms
from transformers import ViTForImageClassification, SwinForImageClassification, Swinv2ForImageClassification, DeiTForImageClassification, BeitForImageClassification
from transformers import AutoImageProcessor
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define the dataset path
dataset_dir = '../../img_dataset_phone'

# select models
vit     = False
swin    = False 
swin2   = True
deit    = False
beit    = False

if vit:
    selected_model = 'vit'
elif swin:
    selected_model = 'swin'
elif swin2:
    selected_model = 'swin2'
elif deit:
    selected_model = 'deit'
elif beit:
    selected_model = 'beit'
else:
    raise ValueError('[ERROR] No selected model')

# define models
if swin:
    model_name = 'microsoft/swin-tiny-patch4-window7-224'
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = SwinForImageClassification.from_pretrained(
        model_name,
        num_labels=36,
        ignore_mismatched_sizes=True  
    )
elif swin2:
    model_name = 'microsoft/swinv2-tiny-patch4-window16-256'
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = Swinv2ForImageClassification.from_pretrained(
        model_name,
        num_labels=36,
        ignore_mismatched_sizes=True  
    )
elif vit:
    model_name = 'google/vit-base-patch16-224-in21k'
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=36,
        # ignore_mismatched_sizes=True,
    )
elif deit:
    model_name = 'facebook/deit-base-distilled-patch16-224'
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = DeiTForImageClassification.from_pretrained(
        model_name,
        num_labels=36,
        # ignore_mismatched_sizes=True
    )
elif beit:
    model_name = 'microsoft/beit-base-patch16-224-pt22k-ft22k'  
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = BeitForImageClassification.from_pretrained(
        model_name,
        num_labels=36,
        ignore_mismatched_sizes=True  # Add this if necessary
    )
else:
    raise ValueError('[ERROR] Select Your Model')

# Define transformations
if vit or swin or deit or beit:
    print("vit/swin/deit/beit activated")
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
    test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
elif swin2:
    print("swin2 activated")
    train_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
    test_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
else:
    raise ValueError('[ERROR] Define any transformations')

# Load the dataset
full_dataset = ImageFolder(root=dataset_dir, transform=train_transforms)


Some weights of Swinv2ForImageClassification were not initialized from the model checkpoint at microsoft/swinv2-tiny-patch4-window16-256 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([36, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([36]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


swin2 activated


In [3]:
# Split the dataset per class into train and test sets
from collections import defaultdict
from sklearn.model_selection import train_test_split

# Extract labels (targets) from the dataset
targets = [sample[1] for sample in full_dataset.samples]  # Assuming ImageFolder's samples attribute

# First split: Train and Temp (Val + Test)
train_indices, temp_indices, y_train, y_temp = train_test_split(
    range(len(targets)),
    targets,
    test_size=0.3,  # 30% of the data will go to val+test
    stratify=targets,
    random_state=42
)

# Second split: Validation and Test
val_indices, test_indices, y_val, y_test = train_test_split(
    temp_indices,
    y_temp,
    test_size=0.33,  # 33% of the temp data goes to test, resulting in 20% test of the total data
    stratify=y_temp,
    random_state=42
)

# Create Subset datasets
train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
test_dataset = torch.utils.data.Subset(full_dataset, test_indices)

# Apply test transforms to validation and test datasets
val_dataset.dataset.transform = test_transforms
test_dataset.dataset.transform = test_transforms

# Print dataset sizes to verify
print(f"Train set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}, Test set size: {len(test_dataset)}")

Train set size: 630, Validation set size: 180, Test set size: 90


In [4]:
# Create DataLoaders
train_loader    = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader      = DataLoader(val_dataset, batch_size=16)
test_loader     = DataLoader(test_dataset, batch_size=16)

In [5]:
# Move the model to the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Swinv2ForImageClassification(
  (swinv2): Swinv2Model(
    (embeddings): Swinv2Embeddings(
      (patch_embeddings): Swinv2PatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Swinv2Encoder(
      (layers): ModuleList(
        (0): Swinv2Stage(
          (blocks): ModuleList(
            (0-1): 2 x Swinv2Layer(
              (attention): Swinv2Attention(
                (self): Swinv2SelfAttention(
                  (continuous_position_bias_mlp): Sequential(
                    (0): Linear(in_features=2, out_features=512, bias=True)
                    (1): ReLU(inplace=True)
                    (2): Linear(in_features=512, out_features=3, bias=False)
                  )
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96

In [6]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm  # For progress bars

# Set up the optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

# Learning rate scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.1, verbose=True)

# Early stopping parameters
best_val_accuracy = 0.0
patience = 25  # Number of epochs to wait before early stopping
epochs_no_improve = 0




In [None]:
# Training loop
num_epochs = 1000

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    # Training
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    
    for batch in tqdm(train_loader, desc="Training", leave=False):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    train_loss = train_loss / total
    train_accuracy = correct / total
    
    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation", leave=False):
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)
            
            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_loss = val_loss / total
    val_accuracy = correct / total
    
    print(f'Epoch {epoch+1}/{num_epochs}, '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}')
    
    # Step the scheduler
    scheduler.step(val_accuracy)
    
    # Check for improvement
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        epochs_no_improve = 0
        # Save the best model
        torch.save(model.state_dict(), f'best_model_phone_{selected_model}.pth')
        print("Validation accuracy improved, model saved.")
    else:
        epochs_no_improve += 1
        print(f"No improvement in validation accuracy for {epochs_no_improve} epochs.")
    
    # Early stopping
    if epochs_no_improve >= patience:
        print("Early stopping triggered.")
        break


Epoch 1/1000


                                                           

Epoch 1/1000, Train Loss: 3.6132, Train Acc: 0.0206, Val Loss: 3.5823, Val Acc: 0.0278
Validation accuracy improved, model saved.

Epoch 2/1000


                                                           

Epoch 2/1000, Train Loss: 3.5907, Train Acc: 0.0159, Val Loss: 3.5795, Val Acc: 0.0167
No improvement in validation accuracy for 1 epochs.

Epoch 3/1000


                                                           

Epoch 3/1000, Train Loss: 3.5866, Train Acc: 0.0222, Val Loss: 3.5644, Val Acc: 0.0500
Validation accuracy improved, model saved.

Epoch 4/1000


                                                           

Epoch 4/1000, Train Loss: 3.5188, Train Acc: 0.0698, Val Loss: 3.2581, Val Acc: 0.0944
Validation accuracy improved, model saved.

Epoch 5/1000


                                                           

Epoch 5/1000, Train Loss: 3.1704, Train Acc: 0.1111, Val Loss: 2.5052, Val Acc: 0.2944
Validation accuracy improved, model saved.

Epoch 6/1000


                                                           

Epoch 6/1000, Train Loss: 2.3571, Train Acc: 0.2730, Val Loss: 1.5280, Val Acc: 0.5056
Validation accuracy improved, model saved.

Epoch 7/1000


                                                           

Epoch 7/1000, Train Loss: 1.6139, Train Acc: 0.4778, Val Loss: 1.1065, Val Acc: 0.6444
Validation accuracy improved, model saved.

Epoch 8/1000


                                                           

Epoch 8/1000, Train Loss: 1.1111, Train Acc: 0.6286, Val Loss: 0.9155, Val Acc: 0.6611
Validation accuracy improved, model saved.

Epoch 9/1000


                                                           

Epoch 9/1000, Train Loss: 0.6498, Train Acc: 0.7873, Val Loss: 0.7273, Val Acc: 0.7611
Validation accuracy improved, model saved.

Epoch 10/1000


                                                           

Epoch 10/1000, Train Loss: 0.4703, Train Acc: 0.8524, Val Loss: 0.9609, Val Acc: 0.7167
No improvement in validation accuracy for 1 epochs.

Epoch 11/1000


                                                           

Epoch 11/1000, Train Loss: 0.5068, Train Acc: 0.8381, Val Loss: 0.6670, Val Acc: 0.8222
Validation accuracy improved, model saved.

Epoch 12/1000


                                                           

Epoch 12/1000, Train Loss: 0.3699, Train Acc: 0.8873, Val Loss: 0.5661, Val Acc: 0.8278
Validation accuracy improved, model saved.

Epoch 13/1000


                                                           

Epoch 13/1000, Train Loss: 0.3071, Train Acc: 0.9079, Val Loss: 0.5143, Val Acc: 0.8167
No improvement in validation accuracy for 1 epochs.

Epoch 14/1000


                                                           

Epoch 14/1000, Train Loss: 0.2071, Train Acc: 0.9413, Val Loss: 0.4846, Val Acc: 0.8500
Validation accuracy improved, model saved.

Epoch 15/1000


                                                           

Epoch 15/1000, Train Loss: 0.1770, Train Acc: 0.9460, Val Loss: 0.3501, Val Acc: 0.8944
Validation accuracy improved, model saved.

Epoch 16/1000


                                                           

Epoch 16/1000, Train Loss: 0.1657, Train Acc: 0.9619, Val Loss: 0.3565, Val Acc: 0.8833
No improvement in validation accuracy for 1 epochs.

Epoch 17/1000


                                                           

Epoch 17/1000, Train Loss: 0.1130, Train Acc: 0.9730, Val Loss: 0.5051, Val Acc: 0.8556
No improvement in validation accuracy for 2 epochs.

Epoch 18/1000


                                                           

Epoch 18/1000, Train Loss: 0.1685, Train Acc: 0.9381, Val Loss: 0.3653, Val Acc: 0.9000
Validation accuracy improved, model saved.

Epoch 19/1000


                                                           

Epoch 19/1000, Train Loss: 0.1755, Train Acc: 0.9460, Val Loss: 0.3847, Val Acc: 0.8889
No improvement in validation accuracy for 1 epochs.

Epoch 20/1000


                                                           

Epoch 20/1000, Train Loss: 0.1013, Train Acc: 0.9698, Val Loss: 0.3216, Val Acc: 0.9000
No improvement in validation accuracy for 2 epochs.

Epoch 21/1000


Training:  75%|███████▌  | 30/40 [00:29<00:07,  1.28it/s]

### Testing

In [None]:
# Testing the model and generating a classification report
from sklearn.metrics import classification_report

# Load the best model 
model.load_state_dict(torch.load(f'best_model_phone_{selected_model}.pth'))

# Collect all predictions and labels
all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for batch in test_loader:
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = model(inputs)
        _, predicted = torch.max(outputs.logits, 1)
        
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Generate classification report
print(classification_report(all_labels, all_preds, digits=4))

              precision    recall  f1-score   support

           0     1.0000    1.0000    1.0000         2
           1     1.0000    1.0000    1.0000         3
           2     1.0000    1.0000    1.0000         3
           3     1.0000    1.0000    1.0000         3
           4     1.0000    1.0000    1.0000         2
           5     1.0000    1.0000    1.0000         2
           6     1.0000    1.0000    1.0000         3
           7     1.0000    0.3333    0.5000         3
           8     1.0000    0.5000    0.6667         2
           9     0.6667    1.0000    0.8000         2
          10     1.0000    1.0000    1.0000         3
          11     1.0000    1.0000    1.0000         3
          12     1.0000    1.0000    1.0000         2
          13     0.6667    1.0000    0.8000         2
          14     1.0000    1.0000    1.0000         3
          15     1.0000    1.0000    1.0000         2
          16     1.0000    1.0000    1.0000         2
          17     1.0000    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
