**Step 0 :** Prep the data frame for dataloader creation

In [1]:
import pandas as pd
from datetime import datetime
from sklearn.preprocessing import LabelEncoder
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


df = pd.read_csv('../dataset/train_val.csv')


df['date'] = pd.to_datetime(df['date'])
# Extract date features
df['year'] = df['date'].dt.year
df['month'] = df['date'].dt.month
df['day'] = df['date'].dt.day
df['day_of_week'] = df['date'].dt.dayofweek
df['quarter'] = df['date'].dt.quarter

# Check if date has timezone information
has_tz = df['date'].dt.tz is not None

# Fix the timezone issue with days_since_upload calculation
if has_tz:
    # Method 1: Make reference date timezone-aware
    from datetime import timezone
    reference_date = datetime.now(timezone.utc)
else:
    # Method 2: Use a timezone-naive reference date
    reference_date = pd.Timestamp.now().tz_localize(None)


# Calculate days since upload (using the date of the most recent video as reference)
df['days_since_upload'] = (reference_date - df['date']).dt.days

# Encode channel names
channel_encoder = LabelEncoder()
df['channel_encoded'] = channel_encoder.fit_transform(df['channel'])

sentence_encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
all_embeddings = []

for idx in tqdm(range(len(df))):
    row = df.iloc[idx]
    # Get embeddings using your existing method
    title = row['title']
    description = row['description']
    channel = row['channel']
    date = row['date']
    field_embeddings = sentence_encoder.encode([title, description, channel, date], convert_to_tensor=True)
    # Store them in numpy arrays
    all_embeddings.append(field_embeddings.cpu().numpy())
    
df['embeddings'] = all_embeddings    
df.head()

100%|██████████| 15482/15482 [01:14<00:00, 207.01it/s]


Unnamed: 0.1,Unnamed: 0,id,channel,title,date,description,views,year,month,day,day_of_week,quarter,days_since_upload,channel_encoded,embeddings
0,0,--2s6hjGrm4,UC-1rx8j9Ggp8mp4uD0ZdEIA,"CGI & VFX Breakdowns: ""Warzone"" - by Ramesh Th...",2020-12-15 05:00:01+00:00,"Check out this revealing VFX Breakdown ""Warzon...",12299,2020,12,15,1,4,1611,0,"[[-0.07760659, -0.001022775, -0.09010337, -0.0..."
1,1,--DnfroyKQ8,UC-1rx8j9Ggp8mp4uD0ZdEIA,"A Sci-Fi Short Film: ""Exit"" - by Ng King Kwan ...",2020-07-01 16:00:00+00:00,"TheCGBros Presents ""Exit"" by Ng King Kwan - Th...",7494,2020,7,1,2,3,1777,0,"[[-0.022426384, 0.05459995, -0.0177436, 0.0594..."
2,2,--aiU7VQKEw,UC-1rx8j9Ggp8mp4uD0ZdEIA,"CGI 3D Animated Short: ""Lost Love"" - by Akash ...",2019-02-18 20:30:00+00:00,"TheCGBros Presents ""Lost Love"" by Akash Manack...",11831,2019,2,18,0,1,2276,0,"[[-0.11143896, 0.022581432, 0.016571341, -0.02..."
3,6,-0SrlZAvSVM,UCW6NyJ6oFLPTnx7iGRZXDDg,Jo Goes Hunting - Careful | Animated music vid...,2020-03-10 14:30:01+00:00,"On the borderless map of a magical planet, lit...",2248,2020,3,10,1,1,1890,28,"[[-0.021549331, 0.040397692, -0.0008517903, -0..."
4,10,-13Y2Pe7kFs,UC-1rx8j9Ggp8mp4uD0ZdEIA,"CGI VFX Breakdown: ""Logan (Wolverine): Digital...",2017-09-20 20:13:52+00:00,Check out this outstanding behind-the-scenes l...,113806,2017,9,20,2,3,2792,0,"[[-0.08767335, -0.07205786, 0.027961658, -0.06..."


**Step 1 :** Instantiate Dataloader

In [2]:
BATCH_SIZE = 128
EPOCHS = 50
NUMBER_WORKERS = 10
PRE_FETCH = 4

In [3]:
from sklearn.model_selection import train_test_split
from dataset import MultiModalDataset
import torchvision.transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # mimic thumbnail color pop
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.5]*3, std=[0.5]*3),
]) 

# Split the data into training and validation sets
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
train_dataset = MultiModalDataset(train_df, image_transform=transform)
val_dataset = MultiModalDataset(val_df)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUMBER_WORKERS, prefetch_factor=PRE_FETCH, pin_memory=True, persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUMBER_WORKERS, prefetch_factor=PRE_FETCH, pin_memory=True, persistent_workers=True)




**Step 2:** Instantiate the model and optimizer

In [4]:
import matplotlib.pyplot as plt
from IPython.display import clear_output
import gc
import torch
from model import FullModel

# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model
model = FullModel()
model.to(device)

# Optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss() #HuberLoss(delta=1.0)

# Callback function
train_losses = []
val_losses = []
batch_losses = []

def plot_losses():
    clear_output(wait=True)
    plt.figure(figsize=(12, 5))
    
    # Plot batch losses
    plt.subplot(1, 2, 1)
    plt.plot(batch_losses)
    plt.title('Loss per Batch')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.grid(True)
    
    # Plot epoch losses
    plt.subplot(1, 3, 2)
    epochs = range(1, len(train_losses) + 1)
    plt.plot(epochs, train_losses, 'b-', label='Training')
    plt.plot(epochs, val_losses, 'r-', label='Validation')
    plt.title('Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()


ModuleNotFoundError: No module named 'topic_encoding'

**Step 2.5 :** Load all pretrained models 

In [None]:
model.baseline.load_state_dict(torch.load('pretrained_weights/best_views_predictor.pth'))
model.vision.load_state_dict(torch.load('pretrained_weights/flashiness_encoder.pth'))
model.topics.load_state_dict(torch.load('pretrained_weights/metadata_fusion.pth'))

**Step 3:** Training loop

In [None]:
EPOCHS = 50

In [None]:
import os
import torch.multiprocessing as mp
if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0.0
    for i, (channel_id, numeric_features, image_features, topic_embeddings, views) in enumerate(train_loader):
        channel_id = channel_id.to(device)
        numeric_features = numeric_features.to(device)
        image_features = image_features.to(device)
        topic_embeddings = topic_embeddings.to(device)
        views = views.to(device)
        
        fields = (channel_id, numeric_features, image_features, topic_embeddings)
        targets = views

        optimizer.zero_grad()
        preds = model(fields)
        preds = preds.squeeze(-1) # Remove extra dimension
        loss = loss_fn(preds, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        total_train_loss += loss.item() * len(fields)
        
        # Move tensors to CPU to free GPU memory
        fields = fields.cpu()
        targets = targets.cpu()
        
        
            
    # Validation phase
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for channel_id, numeric_features, image_features, topic_embeddings, targets in val_loader:
            fields = fields.to(device)
            targets = targets.to(device)
            
            outputs = model(fields)
            outputs = outputs.squeeze(-1) 
            val_loss = loss_fn(outputs, targets)
            total_val_loss += val_loss.item() * len(fields)
            
            # Move tensors to CPU to free GPU memory
            fields = fields.cpu()
            targets = targets.cpu()


    # Calculate average losses
    avg_train_loss = total_train_loss / len(train_dataset)
    avg_val_loss = total_val_loss / len(val_dataset)
    
    # Store losses for plotting
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    if epoch % 20 == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    
    # Free up memory
    torch.cuda.empty_cache()
    gc.collect()
    
    plot_losses()
    
plot_losses()