In [1]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import matplotlib.pyplot as plt
from angle_estimation_model import AngleDataset, AngleEstimationModel, train_epoch, evaluate_model, mean_shift, split_dataset
from tqdm import tqdm
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import math
from sincos_model import SinCosDataset, SinCosModel, SinCosResModel

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print("Using", device)

Using mps


In [3]:
torch.manual_seed(0)

<torch._C.Generator at 0x12000a670>

In [4]:
N = 24           # Number of orientations per task
M = 4               # Number of tasks (discretization)
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30
IMG_SIZE = 224


In [5]:
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE,IMG_SIZE)),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    # transforms.Grayscale(num_output_channels=3),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    # transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE,IMG_SIZE)),
    # transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
data_dir = "/Users/vlad.sarm/Documents/sausage_rotation_estimation/data"  # Update this path to your dataset
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
raw_dir = os.path.join(data_dir, "by_degrees")

In [7]:
split_dataset(raw_dir, data_dir, val_ratio=0.1)

Dataset split complete: 5097 training images, 367 validation images


(5097, 367)

In [8]:
train_dataset = SinCosDataset(
    root_dir=train_dir,
    transform=train_transforms,
)

val_dataset = SinCosDataset(
    root_dir=val_dir,
    transform=val_transforms,
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 5097
Validation samples: 367


In [9]:
# Initialize model
model = SinCosResModel()  # Set feature_extract=False to fine-tune the whole model
model = model.to(device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

In [10]:
def vec2deg(v: torch.Tensor) -> torch.Tensor:
    """v is (N,2) on CPU; returns angles in [0,360)"""
    ang = torch.atan2(v[:,1], v[:,0])   # (-π, π]
    deg = torch.rad2deg(ang) % 360
    return deg

def mae_ang(pred_vec, true_vec):
    """
    pred_vec, true_vec: (B,2) unit vectors
    Returns the mean absolute angular error in degrees.
    """
    # convert to degrees
    pred_deg = vec2deg(pred_vec)
    true_deg = vec2deg(true_vec)
    # wrap-aware difference in (-180,180]
    diff = (pred_deg - true_deg + 180) % 360 - 180
    # Mean absolute error
    return diff.abs().mean().item()

In [11]:
def step(dataloader, train: bool = False):
    model.train(train)
    total_loss, n = 0.0, 0
    total_mae = 0
    for imgs, vec in dataloader:
        imgs, vec = imgs.to(device), vec.to(device)
        
        pred = model(imgs)
        if train:
            optimizer.zero_grad()
        loss = criterion(pred, vec)
        if train:
            loss.backward(); optimizer.step()
        bs = imgs.size(0)
        total_loss += loss.item() * bs
        n += bs

        batch_mae = mae_ang(pred.cpu(), vec.cpu())
        total_mae += batch_mae * imgs.size(0)
    mae = total_mae / n
    loss = total_loss / n
    return loss, mae

In [12]:
print(model)

SinCosResModel(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
  

In [13]:
TRAIN = True  # Set to False to skip training and load the model directly

if TRAIN:
    best_val_mae = float('inf')
    best_epoch = 0
    best_model_state = None

    val_mae_not_improved = 0
    patience = 4  # Number of epochs to wait for improvement before stopping

    print("Starting training...")
    for epoch in tqdm(range(NUM_EPOCHS)):
        train_loss, train_mae = step(train_loader, train=True)
        val_loss, val_mae = step(train_loader, train=False)

        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {train_loss:.4f}, Train MAE: {train_mae:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}")

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            best_epoch = epoch
            best_model_state = model.state_dict()
            val_mae_not_improved = 0

            print("Saving model...")
            save_path = f"sincos_"
            os.makedirs(save_path, exist_ok=True)
            prev = os.listdir(save_path)
            for file in prev:
                if file.endswith(".pth"):
                    os.remove(os.path.join(save_path, file))
            torch.save(model.state_dict(), f"{save_path}/MAE-{val_mae:.2f}_EPOCH-{epoch+1}.pth")
        else:
            val_mae_not_improved += 1
            if val_mae_not_improved >= patience:
                print(f"Early stopping at epoch {epoch+1} with best validation MAE: {best_val_mae:.4f} at epoch {best_epoch+1}")
                break
else:
    checkpoint = torch.load("angle_recognition/sincos_N-24_M-4/MAE-17.02_EPOCH-10.pth", map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()


Starting training...


  0%|          | 0/30 [00:00<?, ?it/s]

Epoch [1/30], Train Loss: 0.2044, Train MAE: 28.0671
Val Loss: 0.0260, Val MAE: 10.0357
Saving model...


  3%|▎         | 1/30 [01:12<35:10, 72.77s/it]

Epoch [2/30], Train Loss: 0.0331, Train MAE: 11.3716
Val Loss: 0.0141, Val MAE: 7.4958
Saving model...


  7%|▋         | 2/30 [02:25<34:03, 72.97s/it]

Epoch [3/30], Train Loss: 0.0215, Train MAE: 9.2257
Val Loss: 0.0095, Val MAE: 6.0956
Saving model...


 10%|█         | 3/30 [03:43<33:42, 74.90s/it]

Epoch [4/30], Train Loss: 0.0162, Train MAE: 7.9304
Val Loss: 0.0081, Val MAE: 5.6900
Saving model...


 13%|█▎        | 4/30 [05:34<38:42, 89.31s/it]

Epoch [5/30], Train Loss: 0.0124, Train MAE: 7.0464
Val Loss: 0.0059, Val MAE: 4.7980
Saving model...


 17%|█▋        | 5/30 [06:56<36:02, 86.51s/it]

Epoch [6/30], Train Loss: 0.0105, Train MAE: 6.4817
Val Loss: 0.0049, Val MAE: 4.3975
Saving model...


 20%|██        | 6/30 [08:09<32:53, 82.22s/it]

Epoch [7/30], Train Loss: 0.0093, Train MAE: 6.1433
Val Loss: 0.0040, Val MAE: 4.0174
Saving model...


 23%|██▎       | 7/30 [09:22<30:20, 79.14s/it]

Epoch [8/30], Train Loss: 0.0079, Train MAE: 5.6450
Val Loss: 0.0036, Val MAE: 3.7900
Saving model...


 27%|██▋       | 8/30 [10:34<28:12, 76.93s/it]

Epoch [9/30], Train Loss: 0.0069, Train MAE: 5.2697
Val Loss: 0.0032, Val MAE: 3.5389
Saving model...


 30%|███       | 9/30 [11:46<26:23, 75.41s/it]

Epoch [10/30], Train Loss: 0.0065, Train MAE: 5.1322
Val Loss: 0.0028, Val MAE: 3.3488
Saving model...


 33%|███▎      | 10/30 [12:59<24:51, 74.55s/it]

Epoch [11/30], Train Loss: 0.0062, Train MAE: 4.9894
Val Loss: 0.0027, Val MAE: 3.2434
Saving model...


 37%|███▋      | 11/30 [14:11<23:23, 73.85s/it]

Epoch [12/30], Train Loss: 0.0052, Train MAE: 4.6401
Val Loss: 0.0024, Val MAE: 3.0842
Saving model...


 40%|████      | 12/30 [15:23<21:57, 73.22s/it]

Epoch [13/30], Train Loss: 0.0050, Train MAE: 4.4938
Val Loss: 0.0021, Val MAE: 2.9147
Saving model...


 43%|████▎     | 13/30 [16:35<20:38, 72.88s/it]

Epoch [14/30], Train Loss: 0.0045, Train MAE: 4.2923
Val Loss: 0.0021, Val MAE: 2.8901
Saving model...


 47%|████▋     | 14/30 [17:47<19:21, 72.58s/it]

Epoch [15/30], Train Loss: 0.0045, Train MAE: 4.2634
Val Loss: 0.0020, Val MAE: 2.8185
Saving model...


 50%|█████     | 15/30 [18:59<18:04, 72.28s/it]

Epoch [16/30], Train Loss: 0.0043, Train MAE: 4.1926
Val Loss: 0.0019, Val MAE: 2.7142
Saving model...


 53%|█████▎    | 16/30 [20:10<16:48, 72.05s/it]

Epoch [17/30], Train Loss: 0.0040, Train MAE: 4.0268
Val Loss: 0.0018, Val MAE: 2.6925
Saving model...


 57%|█████▋    | 17/30 [21:22<15:35, 71.94s/it]

Epoch [18/30], Train Loss: 0.0036, Train MAE: 3.8526
Val Loss: 0.0017, Val MAE: 2.5803
Saving model...


 60%|██████    | 18/30 [22:33<14:21, 71.83s/it]

Epoch [19/30], Train Loss: 0.0037, Train MAE: 3.8640
Val Loss: 0.0016, Val MAE: 2.5107
Saving model...


 63%|██████▎   | 19/30 [23:45<13:09, 71.76s/it]

Epoch [20/30], Train Loss: 0.0032, Train MAE: 3.6281
Val Loss: 0.0015, Val MAE: 2.4760
Saving model...


 67%|██████▋   | 20/30 [24:57<11:56, 71.66s/it]

Epoch [21/30], Train Loss: 0.0033, Train MAE: 3.6572
Val Loss: 0.0014, Val MAE: 2.3816
Saving model...


 70%|███████   | 21/30 [26:08<10:44, 71.61s/it]

Epoch [22/30], Train Loss: 0.0033, Train MAE: 3.6463
Val Loss: 0.0013, Val MAE: 2.3241
Saving model...


 77%|███████▋  | 23/30 [28:31<08:21, 71.65s/it]

Epoch [23/30], Train Loss: 0.0031, Train MAE: 3.5638
Val Loss: 0.0014, Val MAE: 2.3500
Epoch [24/30], Train Loss: 0.0029, Train MAE: 3.4362
Val Loss: 0.0012, Val MAE: 2.2144
Saving model...


 83%|████████▎ | 25/30 [30:55<05:58, 71.60s/it]

Epoch [25/30], Train Loss: 0.0031, Train MAE: 3.4475
Val Loss: 0.0018, Val MAE: 2.5612


 87%|████████▋ | 26/30 [32:06<04:46, 71.67s/it]

Epoch [26/30], Train Loss: 0.0032, Train MAE: 3.5773
Val Loss: 0.0012, Val MAE: 2.2424


 90%|█████████ | 27/30 [33:18<03:34, 71.53s/it]

Epoch [27/30], Train Loss: 0.0028, Train MAE: 3.3299
Val Loss: 0.0012, Val MAE: 2.2488
Epoch [28/30], Train Loss: 0.0026, Train MAE: 3.2327
Val Loss: 0.0011, Val MAE: 2.0802
Saving model...


 93%|█████████▎| 28/30 [34:30<02:23, 71.69s/it]

Epoch [29/30], Train Loss: 0.0024, Train MAE: 3.1773
Val Loss: 0.0010, Val MAE: 2.0421
Saving model...


100%|██████████| 30/30 [36:53<00:00, 73.78s/it]

Epoch [30/30], Train Loss: 0.0023, Train MAE: 3.1150
Val Loss: 0.0010, Val MAE: 2.0523



