In [1]:
import os
import subprocess
import sys

# ========== CONFIGURATION ==========
REPO_URL = "https://github.com/MuhammadQaiser1921/swin-model.git"
REPO_NAME = "swin-model"
REPO_BRANCH = "main"  # <--- Specify your branch here
WORK_DIR = "/kaggle/working"
REPO_PATH = f"{WORK_DIR}/{REPO_NAME}"

# ========== CLONE OR PULL REPO ==========
# Every git command targets REPO_PATH explicitly (as the clone destination or via `cwd=`)
# instead of changing the kernel's working directory. This way the clone always lands where
# the rest of the notebook expects it, and a failing git step cannot leave the kernel
# stranded inside the repo.
if not os.path.exists(REPO_PATH):
    print(f"ðŸ“Œ Cloning branch '{REPO_BRANCH}' from {REPO_URL}...")
    # Clone only the chosen branch, directly into REPO_PATH.
    subprocess.run(["git", "clone", "-b", REPO_BRANCH, REPO_URL, REPO_PATH], check=True)
else:
    print(f"ðŸ“Œ Repository exists. Fetching updates for branch '{REPO_BRANCH}'...")

    # Discard local changes to tracked files, then fetch updates from all remotes.
    subprocess.run(["git", "reset", "--hard"], cwd=REPO_PATH, check=True)
    subprocess.run(["git", "fetch", "--all"], cwd=REPO_PATH, check=True)

    # Ensure we are on the requested branch and pull its latest commits from origin.
    subprocess.run(["git", "checkout", REPO_BRANCH], cwd=REPO_PATH, check=True)
    subprocess.run(["git", "pull", "origin", REPO_BRANCH], cwd=REPO_PATH, check=True)

# ========== SETUP PATHS & REQUIREMENTS ==========
# Make the repo's src/ directory importable. Guard against duplicates so re-running
# this cell does not keep appending the same entry to sys.path.
SRC_PATH = f"{REPO_PATH}/src"
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

# Install requirements if the file exists in the repo (uses this kernel's interpreter).
req_file = f"{REPO_PATH}/requirements.txt"
if os.path.exists(req_file):
    print("Installing requirements...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-r", req_file], check=True)

print(f"âœ… Repository (Branch: {REPO_BRANCH}) is ready and paths are configured.")

ðŸ“Œ Cloning branch 'main' from https://github.com/MuhammadQaiser1921/swin-model.git...


Cloning into 'swin-model'...


Installing requirements...
âœ… Repository (Branch: main) is ready and paths are configured.


In [2]:
from train_video import load_and_prepare_data

# Read the dataset into memory a single time; later cells reuse `data`.
# Pass an integer (e.g. 500) instead of None to load a small subset for a quick smoke test.
data = load_and_prepare_data(max_images=None)

# Report how many samples landed in each split.
print("\nðŸ“Š Data Preparation Complete:")
for split_label, split_key in (("Training", "train_paths"), ("Validation", "val_paths")):
    print(f"   {split_label} samples: {len(data[split_key])}")

2026-02-28 20:56:32.100648: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1772312192.561077      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1772312192.698929      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1772312193.821865      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1772312193.821910      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1772312193.821913      55 computation_placer.cc:177] computation placer alr

ImportError: cannot import name 'build_model' from 'swin_transformer' (/kaggle/working/swin-model/src/swin_transformer.py)

In [None]:
import importlib
import swin_transformer
import train_video

# Force reload the modules to pick up any changes pulled into the repository.
# swin_transformer is reloaded first since train_video presumably imports from it.
importlib.reload(swin_transformer)
importlib.reload(train_video)

# Import the function and the Config class from the freshly reloaded train_video.
from train_video import run_training_session, Config


def _config_attr(name):
    """Return ``Config.<name>``, falling back to the uppercase spelling ``Config.<NAME>``.

    This notebook reads lowercase hyperparameter names, but an earlier note said Config
    defines uppercase ones. Accepting either keeps the call working under both conventions.
    The lowercase name is tried first, so existing behavior is unchanged when it exists.

    Raises:
        AttributeError: if Config defines neither spelling.
    """
    for attr in (name, name.upper()):
        if hasattr(Config, attr):
            return getattr(Config, attr)
    raise AttributeError(f"Config defines neither {name!r} nor {name.upper()!r}")


# Train on the in-memory `data` from the data-loading cell, using Config's hyperparameters.
model, history = run_training_session(
    data,
    epochs=_config_attr("epochs"),
    batch_size=_config_attr("batch_size"),
    lr=_config_attr("lr"),
)

print("\nâœ… Training session completed with the latest repository code.")

In [None]:
import matplotlib.pyplot as plt

# Side-by-side train vs. validation curves for accuracy and loss.
# `history` is presumably a Keras History object: `history.history` maps each metric name
# (and its 'val_' counterpart) to a list of per-epoch values. TODO confirm against train_video.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (metric, title) in zip(axes, [("accuracy", "Accuracy"), ("loss", "Loss")]):
    ax.plot(history.history[metric], label="Train")
    ax.plot(history.history[f"val_{metric}"], label="Val")
    ax.set(title=title, xlabel="Epoch", ylabel=title)
    ax.legend()
plt.show()