# still trying to figure out why old training was better

In [44]:
from configs import Config
from train import train_run_r3d18_1, run_2
from torchvision.transforms import v2
import os
from video_dataset import VideoDataset
from torch.utils.data import DataLoader
import torch
import models.pytorch_r3d as resnet_3d
import torch.optim as optim
import torch.nn as nn
import numpy as np
import random
from utils import enum_dir
from torch.utils.tensorboard import SummaryWriter 

In [2]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed()

In [3]:
root = '../data/WLASL2000'
split = 'asl100'
labels = f'./preprocessed/labels/{split}'
output = f'runs/{split}/compr3d18_000'

conf_pathA = f'./configfiles/{split}/r3d18_005.ini'
conf_pathB = f'./configfiles/{split}/r3d18_007.ini'

confA = Config(conf_pathA)
confB = Config(conf_pathB)

print(confA)
print('-'*10)
print(confB)

Config:
			Training: bs=6, steps=64000, ups=1
			Optimizer: lr=0.001, eps=0.001, wd=1e-07
----------
Config:
			Model: <class 'models.pytorch_r3d.Resnet3D18_basic'>
			Weights Path: None
			Frozen layers: []
			Scheduler: t_max=100, eta_min=1e-06
			Training: bs=6, steps=64000, ups=1
			Optimizer: lr=0.001, eps=0.001, wd=1e-07
			Backbone: lr=1e-05, wd=0.0001
			Classifier: lr=0.001, wd=0.0001


## Comparing transforms

### train.train_run_r3d18_1

In [4]:
base_mean = [0.43216, 0.394666, 0.37645]
base_std = [0.22803, 0.22145, 0.216989]


r3d18_final = v2.Compose([
  v2.Lambda(lambda x: x.float() / 255.0),
  # v2.Lambda(lambda x: vt.normalise(x, base_mean, base_std)),
  v2.Normalize(mean=base_mean, std=base_std),
  v2.Lambda(lambda x: x.permute(1,0,2,3)) 
])

#setup dataset 
train_transformsA = v2.Compose([v2.RandomCrop(224),
                                v2.RandomHorizontalFlip(),
                                r3d18_final])
test_transformsA = v2.Compose([v2.CenterCrop(224),
                              r3d18_final])

In [5]:
print(train_transformsA)
print(test_transformsA)

Compose(
      RandomCrop(size=(224, 224), pad_if_needed=False, fill=0, padding_mode=constant)
      RandomHorizontalFlip(p=0.5)
      Compose(
        Lambda(<lambda>, types=['object'])
        Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989], inplace=False)
        Lambda(<lambda>, types=['object'])
  )
)
Compose(
      CenterCrop(size=(224, 224))
      Compose(
        Lambda(<lambda>, types=['object'])
        Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989], inplace=False)
        Lambda(<lambda>, types=['object'])
  )
)


### train.run_2

In [6]:
train_transformsB, test_transformsB = confB.get_transforms()

In [7]:
print(train_transformsB)
print(test_transformsB)

Compose(
      RandomCrop(size=(224, 224), pad_if_needed=False, fill=0, padding_mode=constant)
      RandomHorizontalFlip(p=0.5)
      Compose(
        Lambda(<lambda>, types=['object'])
        Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989], inplace=False)
        Lambda(<lambda>, types=['object'])
  )
)
Compose(
      CenterCrop(size=(224, 224))
      Compose(
        Lambda(<lambda>, types=['object'])
        Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989], inplace=False)
        Lambda(<lambda>, types=['object'])
  )
)


In [8]:
A_objs = [ test_transformsA, train_transformsA]
B_objs = [ test_transformsB, train_transformsB]
for i, (obja, objb) in enumerate(zip(A_objs, B_objs)):
  print(i)
  assert(str(obja) == str(objb))

0
1


## Setup data


In [9]:
train_instances = os.path.join(labels, 'train_instances_fixed_frange_bboxes_len.json')
val_instances = os.path.join(labels,'val_instances_fixed_frange_bboxes_len.json' )
train_classes = os.path.join(labels, 'train_classes_fixed_frange_bboxes_len.json')
val_classes = os.path.join(labels,'val_classes_fixed_frange_bboxes_len.json' )
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
print(train_instances)
print(train_classes)
print(val_instances)
print(val_classes)
print(device)

./preprocessed/labels/asl100/train_instances_fixed_frange_bboxes_len.json
./preprocessed/labels/asl100/train_classes_fixed_frange_bboxes_len.json
./preprocessed/labels/asl100/val_instances_fixed_frange_bboxes_len.json
./preprocessed/labels/asl100/val_classes_fixed_frange_bboxes_len.json
cuda


### train.train_run_r3d18_1


In [11]:
train_setA = VideoDataset(root,train_instances, train_classes,
                          transforms=train_transformsA, num_frames=32)
train_loaderA = DataLoader(train_setA, batch_size=confA.batch_size,
    shuffle=True, num_workers=2,pin_memory=True)
num_classesA  = len(set(train_setA.classes))

val_setA = VideoDataset(root, val_instances, val_classes,
    transforms=test_transformsA, num_frames=32)
val_loaderA = DataLoader(val_setA,
    batch_size=confA.batch_size, shuffle=True, num_workers=2,pin_memory=False)
assert num_classesA == len(set(val_setA.classes))

dataloadersA = {'train': train_loaderA, 'val':val_loaderA}

In [12]:
print(train_setA)
print(train_loaderA)
print(num_classesA)
print(val_setA)
print(val_loaderA)

<video_dataset.VideoDataset object at 0x7178445a7490>
<torch.utils.data.dataloader.DataLoader object at 0x7179e1507880>
100
<video_dataset.VideoDataset object at 0x7179e1507970>
<torch.utils.data.dataloader.DataLoader object at 0x7179e14cfca0>


### train.run_2

In [13]:
train_setB = VideoDataset(root,train_instances, train_classes,
                          transforms=train_transformsB, num_frames=confB.num_frames)
train_loaderB = DataLoader(train_setB, batch_size=confB.batch_size,
                          shuffle=True, num_workers=2,pin_memory=True)
num_classesB = len(set(train_setB.classes))

val_setB = VideoDataset(root, val_instances, val_classes,
    transforms=test_transformsB, num_frames=confB.num_frames)
val_loaderB = DataLoader(val_setB,
    batch_size=confB.batch_size, shuffle=True, num_workers=2,pin_memory=False)
val_classesB = len(set(val_setB.classes))
assert num_classesB == val_classesB 
assert num_classesB == confB.num_classes

dataloadersB = {'train': train_loaderB, 'val': val_loaderB}

In [14]:
print(train_setB)
print(train_loaderB)
print(num_classesB)
print(val_setB)
print(val_loaderB)

<video_dataset.VideoDataset object at 0x7178445a79d0>
<torch.utils.data.dataloader.DataLoader object at 0x7178445a7fa0>
100
<video_dataset.VideoDataset object at 0x7179f30ce290>
<torch.utils.data.dataloader.DataLoader object at 0x7179e14ceb00>


## Setup model

### train.train_run_r3d18_1

In [15]:
r3d18A = resnet_3d.Resnet3D18_basic(num_classes=num_classesA,
                                    drop_p=confA.drop_p)
print(r3d18A)

Resnet3D18_basic(num_classes=100,
      drop_p=0.3)
        Model architecture:
          Backbone: Sequential(
  (0): BasicStem(
    (0): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(6

### train.run_2

In [16]:
r3d18B = confB.create_model()
print(r3d18B)

Resnet3D18_basic(num_classes=100,
      drop_p=0.3)
        Model architecture:
          Backbone: Sequential(
  (0): BasicStem(
    (0): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(6

In [17]:
assert(str(r3d18A) == str(r3d18B))

## Setup optimizer

### train.train_runr3d18_1 

In [18]:
param_groupsA = [
    {
      'params': r3d18A.backbone.parameters(),
      'lr': 1e-5,  # Low LR for pretrained backbone
      'weight_decay': 1e-4
    },
    {
      'params': r3d18A.classifier.parameters(), 
      'lr': 1e-3,  # Higher LR for new classifier
      'weight_decay': 1e-4
    }
  ]

optimizerA = optim.AdamW(param_groupsA, betas=(0.9, 0.999))

In [19]:
print(param_groupsA)
print()
print(optimizerA)

[{'params': [Parameter containing:
tensor([[[[[-2.9227e-02, -4.1104e-02, -4.1792e-02,  ..., -5.1550e-02,
            -3.8930e-02, -3.9654e-02],
           [-2.1000e-02, -3.2783e-02, -3.3522e-02,  ..., -4.3719e-02,
            -3.6257e-02, -3.5357e-02],
           [-1.5999e-02, -2.7649e-02, -2.6741e-02,  ..., -3.6765e-02,
            -3.1091e-02, -2.8574e-02],
           ...,
           [ 2.6655e-02,  2.2128e-02,  2.8747e-02,  ...,  2.9469e-02,
             2.4539e-02,  2.0806e-02],
           [ 3.5183e-02,  3.2908e-02,  4.2953e-02,  ...,  4.5522e-02,
             3.8277e-02,  3.5271e-02],
           [ 4.1102e-02,  3.6569e-02,  4.1683e-02,  ...,  5.0102e-02,
             4.4259e-02,  4.4615e-02]],

          [[-4.8014e-02, -6.8868e-02, -6.9464e-02,  ..., -7.7126e-02,
            -6.2017e-02, -6.0679e-02],
           [-3.2114e-02, -5.0651e-02, -5.4895e-02,  ..., -6.4078e-02,
            -5.5314e-02, -5.2160e-02],
           [-2.0903e-02, -3.7985e-02, -4.1331e-02,  ..., -4.9900e-02,
     

### train.run_2

In [20]:
param_groupsB = [ 
  {
    'params': r3d18B.backbone.parameters(),
    'lr': confB.backbone_init_lr,  # Low LR for pretrained backbone
    'weight_decay': confB.backbone_weight_decay
  },
  {
    'params': r3d18B.classifier.parameters(), 
    'lr': confB.classifier_init_lr,  # Higher LR for new classifier
    'weight_decay': confB.classifier_weight_decay
  }
]

# optimizer = optim.AdamW(param_groupsB, eps=confB.adam_eps) this was only done for exp11
optimizerB = optim.AdamW(param_groupsB) #this was for exp7

In [30]:
print(param_groupsB)
print()
print(optimizerB)

[{'params': [Parameter containing:
tensor([[[[[-2.9227e-02, -4.1104e-02, -4.1792e-02,  ..., -5.1550e-02,
            -3.8930e-02, -3.9654e-02],
           [-2.1000e-02, -3.2783e-02, -3.3522e-02,  ..., -4.3719e-02,
            -3.6257e-02, -3.5357e-02],
           [-1.5999e-02, -2.7649e-02, -2.6741e-02,  ..., -3.6765e-02,
            -3.1091e-02, -2.8574e-02],
           ...,
           [ 2.6655e-02,  2.2128e-02,  2.8747e-02,  ...,  2.9469e-02,
             2.4539e-02,  2.0806e-02],
           [ 3.5183e-02,  3.2908e-02,  4.2953e-02,  ...,  4.5522e-02,
             3.8277e-02,  3.5271e-02],
           [ 4.1102e-02,  3.6569e-02,  4.1683e-02,  ...,  5.0102e-02,
             4.4259e-02,  4.4615e-02]],

          [[-4.8014e-02, -6.8868e-02, -6.9464e-02,  ..., -7.7126e-02,
            -6.2017e-02, -6.0679e-02],
           [-3.2114e-02, -5.0651e-02, -5.4895e-02,  ..., -6.4078e-02,
            -5.5314e-02, -5.2160e-02],
           [-2.0903e-02, -3.7985e-02, -4.1331e-02,  ..., -4.9900e-02,
     

In [23]:
assert(str(param_groupsA[0]) == str(param_groupsB[0]))

In [24]:
assert(str(param_groupsA[1]) == str(param_groupsB[1]))

AssertionError: 

In [25]:
assert(str(param_groupsA[1]['lr']) == str(param_groupsB[1]['lr']))

In [26]:
assert(str(param_groupsA[1]['weight_decay']) == str(param_groupsB[1]['weight_decay']))

In [27]:
assert(str(param_groupsA[1]['params']) == str(param_groupsB[1]['params']))

AssertionError: 

In [28]:
print(param_groupsA[1]['params'])

[Parameter containing:
tensor([[-0.0022,  0.0207,  0.0352,  ..., -0.0076,  0.0368,  0.0314],
        [-0.0434, -0.0418, -0.0288,  ...,  0.0178,  0.0374,  0.0008],
        [ 0.0032,  0.0334, -0.0286,  ...,  0.0384,  0.0039,  0.0148],
        ...,
        [ 0.0247,  0.0364, -0.0428,  ...,  0.0384, -0.0252,  0.0086],
        [-0.0079,  0.0195, -0.0439,  ..., -0.0061,  0.0225, -0.0231],
        [-0.0130, -0.0331,  0.0213,  ...,  0.0323, -0.0092, -0.0186]],
       requires_grad=True), Parameter containing:
tensor([-0.0074, -0.0364, -0.0029, -0.0156,  0.0441, -0.0287,  0.0432,  0.0301,
        -0.0182, -0.0147, -0.0097,  0.0340,  0.0408,  0.0355, -0.0392, -0.0130,
        -0.0301,  0.0011, -0.0401,  0.0360, -0.0115,  0.0350,  0.0259, -0.0092,
         0.0030, -0.0340, -0.0192,  0.0002,  0.0095,  0.0149, -0.0281,  0.0053,
         0.0303,  0.0032, -0.0035, -0.0148,  0.0431,  0.0033,  0.0123, -0.0306,
         0.0055,  0.0195, -0.0094, -0.0235, -0.0135,  0.0187,  0.0352, -0.0373,
        -0.01

In [29]:
print(param_groupsB[1]['params'])

[Parameter containing:
tensor([[-0.0288, -0.0007,  0.0292,  ..., -0.0291,  0.0420, -0.0233],
        [ 0.0042, -0.0171,  0.0088,  ..., -0.0033, -0.0029, -0.0025],
        [ 0.0196, -0.0416,  0.0122,  ..., -0.0129,  0.0331,  0.0009],
        ...,
        [-0.0134,  0.0179, -0.0413,  ..., -0.0021, -0.0004,  0.0365],
        [ 0.0426, -0.0141,  0.0098,  ...,  0.0365,  0.0130, -0.0208],
        [-0.0166, -0.0070, -0.0409,  ..., -0.0138,  0.0391, -0.0236]],
       requires_grad=True), Parameter containing:
tensor([-8.5526e-03, -7.5515e-03,  3.5958e-02,  1.3158e-02, -1.7032e-02,
        -9.2229e-03,  1.6658e-02,  1.5513e-02, -6.6410e-03, -6.8995e-03,
        -8.8408e-03,  3.0936e-03, -4.0808e-02, -1.7917e-02, -8.0460e-03,
         4.2457e-03, -4.1076e-02, -4.0199e-02,  1.2653e-02, -3.6136e-02,
         1.7476e-02, -3.5973e-03, -3.4126e-02, -1.5499e-02, -3.7468e-02,
        -2.9114e-02,  7.0972e-03, -1.3015e-03, -2.9299e-02, -4.3318e-02,
         2.4420e-02,  2.8815e-02,  4.3356e-02, -1.0356e

i suppose not too unusual that the classifier parameters don't match, but do they have the same shapes

In [31]:
assert(str(optimizerA) == str(optimizerB))

## Setup scheduler & loss func

### train.train_r3d18_1

In [33]:
schedulerA = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerA,
                                                        T_max=100,
                                                        eta_min=1e-6)
loss_funcA = nn.CrossEntropyLoss()

In [34]:
print(schedulerA)
print()
print(loss_funcA)

<torch.optim.lr_scheduler.CosineAnnealingLR object at 0x717830429150>

CrossEntropyLoss()


### train.run_2

In [37]:
schedulerB = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerB,
                                                         T_max=confB.t_max,
                                                         eta_min=confB.eta_min)
loss_funcB = nn.CrossEntropyLoss()
  

In [39]:
print(schedulerB)
print()
print(loss_funcB)

<torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7178305be7d0>

CrossEntropyLoss()


## Training loop:

In [None]:
load = None
if output:
  if load is None: #fresh run, fresh folder
    output = enum_dir(output, make=True) 
  print(f"Output directory set to: {output}")
save_every=5

Output directory set to: runs/asl100/compr3d18_000


In [47]:
saveA='checkpointsA'
saveB='checkpointsB'
# for save in [saveA, saveB]:
# # if save:
#   save_path = os.path.join(output, save)
#   if load is None:
#     save_path = enum_dir(save_path, make=True)
#   print(f"Save directory set to: {save_path}")
save_pathA = os.path.join(output, saveA)
save_pathB = os.path.join(output, saveB)
print(f"Save directory set to: {save_pathA}")
print(f"Save directory set to: {save_pathB}")

Save directory set to: runs/asl100/compr3d18_000/checkpointsA
Save directory set to: runs/asl100/compr3d18_000/checkpointsB


In [45]:
logsA = 'logsA'
logsB = 'logsB'
# if logs:
# for logs in [logsA, logsB]:
#   logs_path = os.path.join(output, logs)
#   if load is None:
#     logs_path = enum_dir(logs_path, make=True)
#   print(f"Logs directory set to: {logs_path}")
logs_pathA = os.path.join(output, logsA)
logs_pathB = os.path.join(output, logsB)
writerA = SummaryWriter(logs_pathA) 
writerB = SummaryWriter(logs_pathB) 
print(f"Logs directory set to: {logs_pathA}")
print(f"Logs directory set to: {logs_pathB}")


Logs directory set to: runs/asl100/compr3d18_000/logsA
Logs directory set to: runs/asl100/compr3d18_000/logsB


### train.train_runr3d18_1

In [None]:
def train_loop_A(r3d18A, confA,optimizerA, loss_funcA, device,
                 dataloadersA, logsA, writerA, save_pathA, saveA,
                 schedulerA, save_every=5, max_epoch=400):
  r3d18A.to(device)
  steps=0
  epoch=0
  best_val_score=0
  num_steps_per_update = confA.update_per_step
  
  while steps < confA.max_steps and epoch < max_epoch:
    print(f'Step {steps}/{confA.max_steps}')
    print('-'*10)
    
    epoch+=1
    #each epoch has training and validation stage
    for phase in ['train', 'val']:
      
      if phase == 'train':
        r3d18A.train()
      else:
        r3d18A.eval()
        
      #Reset matrics for this phase
      running_loss = 0.0
      running_corrects = 0
      total_samples = 0
      num_batches = 0
      # tot_loc_loss = 0.0  #TODO once this gets working try the fancy loss
      # tot_cls_loss = 0.0
      
      #for gradient accumulation  
      accumulated_loss = 0.0
      accumulated_steps = 0
      optimizerA.zero_grad()
    
      #Iterate over data for this phase
      # for batch_idx, (data, target) in enumerate(dataloaders[phase]):
      for batch_idx, item in enumerate(dataloadersA[phase]):
        data, target = item['frames'], item['label_num'] #for compatibility
        data, target = data.to(device), target.to(device)
        batch_size = data.size(0)
        total_samples += batch_size
        num_batches += 1
        
        #Forward pass
        if phase == 'train':
          model_output = r3d18A(data)
        else:
          with torch.no_grad():
            model_output = r3d18A(data)
            
        # Calculate loss
        loss = loss_funcA(model_output, target)

        #Accumulate metrics
        running_loss += loss.item() * batch_size  
        _, predicted = model_output.max(1)
        running_corrects += predicted.eq(target).sum().item()
        

        if phase == 'train':
          scaled_loss = loss / num_steps_per_update
          scaled_loss.backward()
          
          accumulated_loss += loss.item()
          accumulated_steps += 1
          
          if accumulated_steps == num_steps_per_update:
            optimizerA.step()
            optimizerA.zero_grad()
            steps += 1
            
            # Print progress every few steps
            if steps % 10 == 0:
              avg_acc_loss = accumulated_loss / accumulated_steps
              current_acc = 100.0 * running_corrects / total_samples
              print(f'Step {steps}: Accumulated Loss: {avg_acc_loss:.4f}, '
                    f'Current Accuracy: {current_acc:.2f}%')
              
              if logsA:
                writerA.add_scalar('Loss/Train_Step', avg_acc_loss, steps) 
                writerA.add_scalar('Accuracy/Train_Step', current_acc, steps)
            
            # Reset accumulation
            accumulated_loss = 0.0
            accumulated_steps = 0
    
      #calculate  epoch metrics
      epoch_loss = running_loss / total_samples # Average loss per sample
      epoch_acc = 100.0 * running_corrects / total_samples

      print(f'{phase.upper()} - Epoch {epoch}:')
      print(f'  Loss: {epoch_loss:.4f}')
      print(f'  Accuracy: {epoch_acc:.2f}% ({running_corrects}/{total_samples})')
      
      # Log epoch metrics
      if logsA:
        writerA.add_scalar(f'Loss/{phase.capitalize()}', epoch_loss, epoch) 
        writerA.add_scalar(f'Accuracy/{phase.capitalize()}', epoch_acc, epoch) 
      
      # Validation specific logic
      if phase == 'val':
          # Save best model
          if epoch_acc > best_val_score:
              best_val_score = epoch_acc
              model_name = os.path.join(save_pathA, f'best.pth') 
              torch.save(r3d18A.state_dict(), model_name)
              print(f'New best model saved: {model_name} (Acc: {epoch_acc:.2f}%)')
          
          # Step scheduler with validation loss
          # scheduler.step(epoch_loss) # type: ignore
          schedulerA.step() 
          
          print(f'Best validation accuracy so far: {best_val_score:.2f}%')
      
      # Save checkpoint
    if saveA and (epoch % save_every == 0 or not (steps < confA.max_steps and epoch < 400)):
        checkpoint_data = {
            'epoch': epoch,
            'steps': steps,
            'model_state_dict': r3d18A.state_dict(),
            'optimizer_state_dict': optimizerA.state_dict(),
            'scheduler_state_dict': schedulerA.state_dict(),
            'best_val_score': best_val_score
        }
        checkpoint_path = os.path.join(save_path, f'checkpoint_{epoch}.pth') # type: ignore
        torch.save(checkpoint_data, checkpoint_path)
        print(f'Checkpoint saved: {checkpoint_path}')
        
    
  print('Finished training successfully')
  

### train.run_2

In [None]:
def train_loop_B(r3d18B,device, confB, dataloadersB, optimizerB, schedulerB,
                 loss_funcB, writerB, logsB, saveB, save_pathB,
                 save_every=5, max_epochs=400):
  r3d18B.to(device)
  steps=0
  epoch=0
  best_val_score=0

  while steps < confB.max_steps and epoch < max_epochs:
    print(f'Step {steps}/{confB.max_steps}')
    print('-'*10)
    
    epoch+=1
    #each epoch has training and validation stage
    for phase in ['train', 'val']:
      
      if phase == 'train':
        r3d18B.train()
      else:
        r3d18B.eval()
        
      #Reset matrics for this phase
      running_loss = 0.0
      running_corrects = 0
      total_samples = 0
      # num_batches = 0
      # tot_loc_loss = 0.0  #TODO once this gets working try the fancy loss
      # tot_cls_loss = 0.0
      
      #for gradient accumulation  
      accumulated_loss = 0.0
      accumulated_steps = 0
      optimizerB.zero_grad()
    
      #Iterate over data for this phase
      for batch_idx, item in enumerate(dataloadersB[phase]):
        data, target = item['frames'], item['label_num']
        data, target = data.to(device), target.to(device)
        batch_size = data.size(0)
        total_samples += batch_size
        # num_batches += 1
        
        #Forward pass
        if phase == 'train':
          model_output = r3d18B(data)
        else:
          with torch.no_grad():
            model_output = r3d18B(data)
            
        # Calculate loss
        loss = loss_funcB(model_output, target)

        #Accumulate metrics
        running_loss += loss.item() * batch_size  
        _, predicted = model_output.max(1)
        running_corrects += predicted.eq(target).sum().item()
        

        if phase == 'train':
          scaled_loss = loss / confB.update_per_step
          scaled_loss.backward()
          
          accumulated_loss += loss.item()
          accumulated_steps += 1
          
          if accumulated_steps == confB.update_per_step:
            optimizerB.step()
            optimizerB.zero_grad()
            steps += 1
            
            # Print progress every few steps
            if steps % 10 == 0:
              avg_acc_loss = accumulated_loss / accumulated_steps
              current_acc = 100.0 * running_corrects / total_samples
              print(f'Step {steps}: Accumulated Loss: {avg_acc_loss:.4f}, '
                    f'Current Accuracy: {current_acc:.2f}%')
              
              if logsB:
                writerB.add_scalar('Loss/Train_Step', avg_acc_loss, steps) 
                writerB.add_scalar('Accuracy/Train_Step', current_acc, steps) 
            
            # Reset accumulation
            accumulated_loss = 0.0
            accumulated_steps = 0
    
      #calculate  epoch metrics
      epoch_loss = running_loss / total_samples # Average loss per sample
      epoch_acc = 100.0 * running_corrects / total_samples

      print(f'{phase.upper()} - Epoch {epoch}:')
      print(f'  Loss: {epoch_loss:.4f}')
      print(f'  Accuracy: {epoch_acc:.2f}% ({running_corrects}/{total_samples})')
      try:
        for i, param_group in enumerate(optimizerB.param_groups):
          if logsB:
            writerB.add_scalar(f'LearningRate/Group_{i}', param_group['lr'], epoch) 
          print(f"Group {i} learning rate: {param_group['lr']}")
      except Exception as e:
        print(f'Failed to print all learning rates due to {e}')
        
      # Log epoch metrics
      if logsB:
        writerB.add_scalar(f'Loss/{phase.capitalize()}', epoch_loss, epoch) 
        writerB.add_scalar(f'Accuracy/{phase.capitalize()}', epoch_acc, epoch) 
      
      # Validation specific logic
      if phase == 'val':
          # Save best model
          if epoch_acc > best_val_score:
              best_val_score = epoch_acc
              model_name = os.path.join(save_pathB, f'best.pth') 
              torch.save(r3d18B.state_dict(), model_name)
              print(f'New best model saved: {model_name} (Acc: {epoch_acc:.2f}%)')
          
          # Step scheduler with validation loss
          schedulerB.step() 
          
          print(f'Best validation accuracy so far: {best_val_score:.2f}%')
      
      # Save checkpoint
    if saveB and (epoch % save_every == 0 or not (steps < confB.max_steps and epoch < 400)):
        checkpoint_data = {
            'epoch': epoch,
            'steps': steps,
            'model_state_dict': r3d18B.state_dict(),
            'optimizer_state_dict': optimizerB.state_dict(),
            'scheduler_state_dict': schedulerB.state_dict(),
            'best_val_score': best_val_score
        }
        checkpoint_path = os.path.join(save_path, f'checkpoint_{str(epoch).zfill(3)}.pth') # type: ignore
        torch.save(checkpoint_data, checkpoint_path)
        print(f'Checkpoint saved: {checkpoint_path}')

  print('Finished training successfully')
  if logsB:
    writerB.close()

## set to train

In [None]:
train_loop_A(r3d18A, confA, optimizerA, loss_funcA, device, dataloadersA,
             logsA, writerA, save_pathA, saveA, schedulerA, max_epoch=100) #lets not run for too long
train_loop_B(r3d18B, device, confB, dataloadersB, optimizerB, schedulerB,
             loss_funcB, writerB, logsB, saveB, save_pathB, max_epochs=100)