In [12]:
import torch
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
from numpy import random
from tqdm import tqdm
import os
import torch.nn.functional as F
import torch.optim as optim
from Scripts.salientclassifier import SalientClassifier
from torchsummary import summary
from Scripts.ssi import SalientSuperImage
torch.autograd.set_detect_anomaly(True)
from sklearn.metrics import average_precision_score


##################################################################################################################################################################################

color_jitter = transforms.ColorJitter(random.uniform(0.1, 0.5),random.uniform(0.1, 0.5),random.uniform(0.1, 0.5),random.uniform(0.01, 0.15))

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([color_jitter], p=0.5),
    transforms.RandomAutocontrast(p=0.5),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

###################################################################################################################################################################################

train_ds =  SalientSuperImage(root_dir='/home/jparejo/projects/VD/SaliNet/dataset/SCVD/SCVD_converted_sec_split/Train', num_secs=1, k=12, sampler='uniform', aspect_ratio='480p_A', grid_shape=(4,3), transform=train_transform)
test_ds =  SalientSuperImage(root_dir='/home/jparejo/projects/VD/SaliNet/dataset/SCVD/SCVD_converted_sec_split/Test', num_secs=1, k=12, sampler='uniform', aspect_ratio='480p_A', grid_shape=(4,3), transform=test_transform)

batch_size = 32

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=8)

In [None]:
import torch
from tqdm import tqdm
#
import wandb

def train_step(model: torch.nn.Module,
            dataloader: torch.utils.data.DataLoader,
            loss_fn: torch.nn,
            optimizer: torch.optim,
            device: torch.device) -> tuple:
    train_loss, train_correct, train_total = 0, 0, 0
    for n_batch, (X, y) in enumerate(tqdm(dataloader)):
        # Train step
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        loss.backward()
        weight_before = model.fc.weight.clone()  # Suponiendo que tienes una capa fc1 en tu modelo
        optimizer.step()
        weight_after = model.fc.weight
        print("Cambio en pesos:", (weight_before - weight_after).abs().mean())
        
        # Update accuracy
        _, predicted = torch.max(y_pred.data, 1)
        train_correct += (predicted == y).sum().item()
        train_total += y.size(0) # Count the number of labels           
        train_loss += loss.item() # Accumulate loss for each batch           
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"Layer {name} | Gradient max: {param.grad.max()} | min: {param.grad.min()}")

        
    # calculate average losses
    train_loss /= len(dataloader)
    train_acc = 100. * train_correct / train_total
    return model, train_loss, train_acc

def test_step(model: torch.nn.Module,
            dataloader: torch.utils.data.DataLoader,
            loss_fn: torch.nn,
            device: torch.device,
            retrieve_values: bool = False) -> tuple:

    val_loss, val_correct, val_total = 0, 0, 0
    y_total = []
    y_pred_prob_total = []
    model.to(device)
    with torch.inference_mode():
        for n_batch, (X, y) in enumerate(tqdm(dataloader)):
            X, y = X.to(device), y.to(device)
            
            y_pred = model(X)
            loss = loss_fn(y_pred, y)
            y_total.append(y)
            y_pred_prob_total.append(torch.softmax(y_pred, dim=1))
            
            # Update accuracy
            _, predicted = torch.max(y_pred.data, 1)
            val_correct += (predicted == y).sum().item()
            val_total += y.size(0)
            val_loss += loss.item()
            
        # calculate average losses
        val_loss /= len(dataloader)
        val_acc = 100. * val_correct / val_total
    if retrieve_values:
        y_total = torch.cat(y_total, dim=0)
        y_pred_prob_total = torch.cat(y_pred_prob_total, dim=0)
        return val_loss, val_acc, y_total, y_pred_prob_total
    else:
        return val_loss, val_acc

## traces is a dict which define the metrics to be stored and the app to send them to
## traces = {'wandb': {'session': {'project': 'VD', 'group': 'SaliNet', 'name': 'SaliNet-V0'},
#              'config': {'learning_rate': 0.001, 'architecture': 'Salinet-2m', 'dataset': "SCVD", 'epochs': 10, 'footnote': "Some notes."},
#                                   'name': 'SaliNet'},
#           'metrics = ['train_met', 'test_met', 'random_ex', 'pr', 'cm']},
#       'tensorboard': ['performance']}
# Las métricas que guardamos serán las siguientes:
# - train_met: accuracy + loss en el conjunto de entrenamiento
# - test_met: accuracy + loss en el conjunto de test
# - random_ex: ejemplos aleatorios de imágenes y sus predicciones (en prob.)
# - pr: curva PR de ambas clases
# - cm: matriz de confusión
# - performance: tiempos de ejecución de cada componente del modelo
## Hay que preinicializar wandb
                
def train(model: torch.nn.Module,
        train_dataloader: torch.utils.data.DataLoader,
        test_dataloader: torch.utils.data.DataLoader,
        save_path: str,
        device: torch.device,
        loss_fn: torch.nn,
        optimizer: torch.optim,
        lr_scheduler: torch.optim.lr_scheduler = None,
        epochs: int = 100,
        traces: dict = {},
        verbose: bool = True):
    
    ## Initialization 
    best_acc = 0
    model.to(device)
    traces_wandb, traces_tensorboard = traces.get('wandb', []), traces.get('tensorboard', [])
    
    # ## Traces
    # prof = torch.profiler.profile(
    #     schedule = traces_tensorboard['schedule'],
    #     on_trace_ready = torch.profiler.tensorboard_trace_handler(traces_tensorboard['save_path']),
    #     record_shapes=True,
    #     with_stack=False)
    
    # if traces_wandb is not None:
    #     wandb.init(project=traces_wandb['session']['project'],
    #             group=traces_wandb['session']['group'],
    #             name=traces_wandb['session']['name'],
    #             config=traces_wandb['session']['config'])
    
    # prof.start()
    for epoch in range(1, epochs+1):
        last_epoch = epoch == epochs
        # prof.step()
        ## Train
        model.train()
        model, train_loss, train_acc = train_step(model, train_dataloader, loss_fn, optimizer, device)

        # Update learning rate
        if lr_scheduler is not None:
            if verbose:
                print("Updating learning rate.")
            lr_scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Learning Rate: {current_lr}")
        for name, param in model.named_parameters():
            if param.grad is not None:
                print(f"Layer {name} | Gradient max: {param.grad.max()} | min: {param.grad.min()}")
        
        ## Test
        model.eval()
        val_loss, val_acc, y, y_prob = test_step(model, test_dataloader, loss_fn, device, retrieve_values=True)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_model = model.state_dict()
            torch.save(best_model, save_path)

        # # Save traces to wandb
        # if wandb.run is not None:
        #     if 'train_met' in traces_wandb['metrics']:
        #         wandb.log({'train_loss': train_loss, 'train_acc': train_acc}, step=epoch)
        #     if 'test_met' in traces_wandb['metrics']:
        #         wandb.log({'val_loss': val_loss, 'val_acc': val_acc}, step=epoch)
        #     if 'random_ex' in traces_wandb['metrics']:
        #         pass
        #     if 'pr' in traces_wandb['metrics'] and last_epoch: # Only save PR curve in the last epoch
        #         wandb.log({'roc': wandb.plot.roc_curve(y.cpu(),
        #             y_prob.cpu(), labels=["Normal", "Violence"])}, step=epoch)
        #     if 'cm' in traces_wandb['metrics'] and last_epoch: # Only save confusion matrix in the last epoch
        #         wandb.log({'confusion_matrix': wandb.plot.confusion_matrix(
        #             y_prob.cpu(), y.cpu().tolist(), class_names=["Normal", "Violence"])}, step=epoch)
        #     if verbose and last_epoch:
        #         print(f"Epoch {epoch}/{epochs}: Metrics sent to wandb.")
        # else:
        #     print("No wandb run detected. Skipping logging to wandb.")
        
        # # Save traces to tensorboard
        # if 'performance' in traces_tensorboard:
        #     pass
            
        if verbose:
            print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss:.4f} || Train Acc: {train_acc:.2f}% || Val Loss: {val_loss:.4f} || Val Acc: {val_acc:.2f}%")
    # wandb.finish()
    # prof.stop()
    print("Finished Training.")
    return model

In [19]:
for name, param in model.model.named_parameters():
    print(name)

conv1.weight
conv1.kernel_fn.cp
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer3.0.conv1.weight
layer3.0.bn1.weight
layer3.0.bn1.bias
layer3.0.downsample.0.weight
layer3.0.downsample.1.weight
layer3.0.downsample.1.bias
layer4.0.conv1.weight
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.downsample.0.weight
layer4.0.downsample.1.weight
layer4.0.downsample.1.bias
fc.weight
fc.bias


In [15]:
from torchinfo import summary

def get_model(kind: str, verbose: bool = True) -> torch.nn.Module:
    match kind:
        case 'paper':
            model = SalientClassifier("salinet2m", num_classes=3)
            checkpoint = torch.load("weights/SalientClassifier-Salinet2m-SCVD.pth")
            model.load_state_dict(checkpoint)
            for param in model.model.parameters():
                param.requires_grad = False
            num_classes = 2
            model.model.fc = nn.Linear(model.model.fc.in_features, num_classes)
        case 'scratch':
            model = SalientClassifier("salinet2m", num_classes=2)
        case 'keep_training':
            model = SalientClassifier("salinet2m", num_classes=2)
            checkpoint = torch.load("weights/SalientClassifier-caps.pth")
            model.load_state_dict(checkpoint)
            for name, param in model.model.named_parameters():
                if ("fc" not in name):
                    param.requires_grad = False
        case _:
            raise ValueError("Invalid model kind.")
    if verbose:
        summary(model, input_size = (4, 3, 672, 672), verbose=0, col_names=["input_size", "output_size", "num_params", "trainable"], col_width=20)
    return model

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define loss functions
ce_loss = nn.CrossEntropyLoss()

model = get_model('paper')

# Define optimizer and learning rate scheduler
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

num_epochs = 4

  checkpoint = torch.load("weights/SalientClassifier-Salinet2m-SCVD.pth")


In [17]:
model = train(model,
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    save_path = "./weights/SalientClassifier-caps.pth", ### <-
    loss_fn=ce_loss,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    epochs=num_epochs,
    device=device,
    traces = {},
    verbose=True)

  0%|          | 0/58 [00:23<?, ?it/s]


AttributeError: 'SalientClassifier' object has no attribute 'fc1'