In [1]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import matplotlib.pyplot as plt
from angle_estimation_model import AngleDataset, AngleEstimationModel, train_epoch, evaluate_model, mean_shift, split_dataset
from tqdm import tqdm
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import math
from sincos_model import SinCosDataset, SinCosModel, SinCosResModel
from ensemble_model import AngleEnsemble, EnsembleDataset

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print("Using", device)

Using mps


In [3]:
torch.manual_seed(0)

<torch._C.Generator at 0x12880e670>

In [4]:
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30
IMG_SIZE = 224


In [5]:
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE,IMG_SIZE)),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    # transforms.Grayscale(num_output_channels=3),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    # transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE,IMG_SIZE)),
    # transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
data_dir = "/Users/vlad.sarm/Documents/sausage_rotation_estimation/data"  # Update this path to your dataset
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
raw_dir = os.path.join(data_dir, "by_degrees")

In [7]:
split_dataset(raw_dir, data_dir, val_ratio=0.1)

Dataset split complete: 5097 training images, 367 validation images


(5097, 367)

In [8]:
train_dataset = SinCosDataset(
    root_dir=train_dir,
    transform=train_transforms,
)

val_dataset = SinCosDataset(
    root_dir=val_dir,
    transform=val_transforms,
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 5097
Validation samples: 367


In [9]:
# Initialize model
N, M = 24, 4
model_von = AngleEstimationModel(N=N, M=M, feature_extract=False)
checkpoint_von = torch.load("/Users/vlad.sarm/Documents/sausage_rotation_estimation/angle_recognition/rgb2_N-24_M-4/MAE-2.50_DEV-11.42_EPOCH-8.pth", map_location=device)
model_von.load_state_dict(checkpoint_von)

model_sin = SinCosResModel()  # Set feature_extract=False to fine-tune the whole model
checkpoint_sin = torch.load("/Users/vlad.sarm/Documents/sausage_rotation_estimation/angle_recognition/sincos_/MAE-2.04_EPOCH-29.pth", map_location=device)
model_sin.load_state_dict(checkpoint_sin)

model = AngleEnsemble(model_von, model_sin)
model = model.to(device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

In [10]:
def mae_ang(pred_deg, true_deg):
    """
    pred_vec, true_vec: (B,2) unit vectors
    Returns the mean absolute angular error in degrees.
    """
    # wrap-aware difference in (-180,180]
    diff = (pred_deg - true_deg + 180) % 360 - 180
    # Mean absolute error
    return diff.abs().mean().item()

In [11]:
def step(dataloader, train: bool = False):
    model.train(train)
    total_loss, n = 0.0, 0
    total_mae = 0
    for imgs, vec in dataloader:
        imgs, vec = imgs.to(device), vec.to(device)
        
        pred = model(imgs)
        if train:
            optimizer.zero_grad()
        loss = criterion(pred, vec)
        if train:
            loss.backward(); optimizer.step()
        bs = imgs.size(0)
        total_loss += loss.item() * bs
        n += bs

        batch_mae = mae_ang(pred.cpu(), vec.cpu())
        total_mae += batch_mae * imgs.size(0)
    mae = total_mae / n
    loss = total_loss / n
    return loss, mae

In [12]:
print(model)

AngleEnsemble(
  (model_von): AngleEstimationModel(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (r

In [13]:
TRAIN = True  # Set to False to skip training and load the model directly

if TRAIN:
    best_val_mae = float('inf')
    best_epoch = 0
    best_model_state = None

    val_mae_not_improved = 0
    patience = 4  # Number of epochs to wait for improvement before stopping

    print("Starting training...")
    for epoch in tqdm(range(NUM_EPOCHS)):
        train_loss, train_mae = step(train_loader, train=True)
        val_loss, val_mae = step(train_loader, train=False)

        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {train_loss:.4f}, Train MAE: {train_mae:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}")

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            best_epoch = epoch
            best_model_state = model.state_dict()
            val_mae_not_improved = 0

            print("Saving model...")
            save_path = f"ensemble"
            os.makedirs(save_path, exist_ok=True)
            prev = os.listdir(save_path)
            for file in prev:
                if file.endswith(".pth"):
                    os.remove(os.path.join(save_path, file))
            torch.save(model.state_dict(), f"{save_path}/MAE-{val_mae:.2f}_EPOCH-{epoch+1}.pth")
        else:
            val_mae_not_improved += 1
            if val_mae_not_improved >= patience:
                print(f"Early stopping at epoch {epoch+1} with best validation MAE: {best_val_mae:.4f} at epoch {best_epoch+1}")
                break
else:
    checkpoint = torch.load("angle_recognition/sincos_N-24_M-4/MAE-17.02_EPOCH-10.pth", map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()


Starting training...


  0%|          | 0/30 [00:00<?, ?it/s]

Epoch [1/30], Train Loss: 0.8689, Train MAE: 0.7097
Val Loss: 0.7287, Val MAE: 0.5880
Saving model...


  3%|▎         | 1/30 [04:40<2:15:28, 280.31s/it]

Epoch [2/30], Train Loss: 0.6527, Train MAE: 0.5731
Val Loss: 0.4881, Val MAE: 0.4469
Saving model...


  7%|▋         | 2/30 [11:28<2:40:37, 344.18s/it]


KeyboardInterrupt: 