In [3]:
import torch
from vit_model import VisionTransformer

# --- Configuration ---
# IMPORTANT: Change this path to where your downloaded model file is located.
PRETRAINED_MODEL_PATH = '../weights/transformer_120.pth' 

# IMPORTANT: Adjust these parameters to match the model you downloaded.
# These should match the configuration used to train the pretrained model.
# If you don't know them, the defaults (e.g., 'base' ViT) are a good start.
IMG_SIZE = 224
PATCH_SIZE = 16
EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 12
NUM_CLASSES = 1000 # This might be different for a ReID model, but it's a starting point.

# --- Model and Loading ---

# 1. Instantiate the model with the correct architecture
#    The model structure in code must match the structure of the saved weights.
model = VisionTransformer(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    depth=DEPTH,
    n_heads=NUM_HEADS,
    n_classes=NUM_CLASSES,
)

print("--- Attempting to load pretrained weights ---")
print(f"Loading from: {PRETRAINED_MODEL_PATH}")

try:
    # 2. Load the downloaded weights
    #    map_location='cpu' ensures the model loads even if you don't have a GPU
    state_dict = torch.load(PRETRAINED_MODEL_PATH, map_location='cpu')
    
    # Often, pretrained models are saved inside a dictionary with a key like 'model' or 'state_dict'
    # If the direct load fails, you might need to inspect the keys of the loaded dictionary.
    # For example: if 'model' in state_dict: state_dict = state_dict['model']

    # 3. Load the weights into the model
    model.load_state_dict(state_dict, strict=False) # Use strict=False to be more lenient
    
    print("\n[SUCCESS] Pretrained weights loaded successfully!")

    print("\n--- Debugging Mismatched Keys ---")
    # Load the state dict again for inspection
    checkpoint = torch.load(PRETRAINED_MODEL_PATH, map_location='cpu')

    # It might be nested, check for common keys
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    print("\nKeys in the PRETRAINED FILE:")
    for key in checkpoint.keys():
        print(key)

    print("\nKeys in the MODEL ARCHITECTURE:")
    for key in model.state_dict().keys():
        print(key)
        
    print("\n--- End of Debugging ---")

    # Number of parameters in the loaded model and the model architecture
    num_params_loaded = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal trainable parameters in the loaded model: {num_params_loaded}")

    # Number of parameters in the model architecture
    num_params_architecture = sum(p.numel() for p in model.state_dict().values())
    print(f"Total parameters in the model architecture: {num_params_architecture}")

    if num_params_loaded != num_params_architecture:
        print("\n[WARNING] The number of parameters in the loaded model does not match the model architecture.")
        print(f"Loaded parameters: {num_params_loaded}, Architecture parameters: {num_params_architecture}")
        
except Exception as e:
    print(f"\n[ERROR] Failed to load pretrained weights. Reason: {e}")
    print("\n--- Troubleshooting ---")
    print("This error usually means the model architecture in the code does not match the architecture in the saved file.")
    print("To debug, you can print the keys from both your model and the file to see how they differ.")
    # --- Code to add inside the 'except' block for debugging ---
    
# --- Verification ---

# 4. Set the model to evaluation mode
#    This is crucial for getting correct predictions (it disables layers like Dropout)
model.eval()
print("\nModel set to evaluation mode.")

# 5. Perform a forward pass with a dummy image to ensure the model runs
try:
    with torch.no_grad(): # Disable gradient calculation for inference
        dummy_image = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
        output = model(dummy_image)
    
    print("\n--- Verification ---")
    print("[SUCCESS] Forward pass with loaded weights completed without errors.")
    print(f"Input shape:  {dummy_image.shape}")
    print(f"Output shape: {output.shape}")

except Exception as e:
    print(f"\n[ERROR] An error occurred during the forward pass after loading weights: {e}")

--- Attempting to load pretrained weights ---
Loading from: ../weights/transformer_120.pth

[SUCCESS] Pretrained weights loaded successfully!

--- Debugging Mismatched Keys ---

Keys in the PRETRAINED FILE:
base.cls_token
base.pos_embed
base.patch_embed.conv.0.weight
base.patch_embed.conv.1.IN.weight
base.patch_embed.conv.1.IN.bias
base.patch_embed.conv.1.BN.weight
base.patch_embed.conv.1.BN.bias
base.patch_embed.conv.1.BN.running_mean
base.patch_embed.conv.1.BN.running_var
base.patch_embed.conv.1.BN.num_batches_tracked
base.patch_embed.conv.3.weight
base.patch_embed.conv.4.IN.weight
base.patch_embed.conv.4.IN.bias
base.patch_embed.conv.4.BN.weight
base.patch_embed.conv.4.BN.bias
base.patch_embed.conv.4.BN.running_mean
base.patch_embed.conv.4.BN.running_var
base.patch_embed.conv.4.BN.num_batches_tracked
base.patch_embed.conv.6.weight
base.patch_embed.conv.7.weight
base.patch_embed.conv.7.bias
base.patch_embed.conv.7.running_mean
base.patch_embed.conv.7.running_var
base.patch_embed.conv

In [2]:
import cv2 as cv
from ultralytics import YOLO
video_file = '../v_0/input/3c.mp4'

# Take frame number 100 from the video
cap = cv.VideoCapture(video_file)
frame_number = 110
cap.set(cv.CAP_PROP_POS_FRAMES, frame_number)
ret, frame = cap.read()
cap.release()
if not ret:
    print(f"[ERROR] Could not read frame {frame_number} from video {video_file}.")
else:
    print(f"[SUCCESS] Successfully read frame {frame_number} from video {video_file}.")

    # Using Yolo detect the person and save the crops
    model = YOLO('../weights/yolo11m.pt')  # Load the YOLO model
    results = model(frame)  # Perform inference on the frame

    # Extract the bounding boxes for persons
    person_boxes = []
    for result in results:
        for box in result.boxes:
            if box.cls == 0:  # Class 0 is typically 'person' in YOLO
                person_boxes.append(box.xyxy.cpu().numpy())
    if not person_boxes:
        print("[WARNING] No persons detected in the frame.")
    else:
        print(f"[SUCCESS] Detected {len(person_boxes)} persons in the frame.")

    # Crop and save the detected persons
    for i, box in enumerate(person_boxes):
        print(f"[INFO] Cropping person {i} with bounding box: {box}")
        x1, y1, x2, y2 = map(int, box.flatten())
        person_crop = frame[y1:y2, x1:x2]
        cv.imwrite(f'person_crop_{i}.jpg', person_crop)
        print(f"[SUCCESS] Saved cropped person {i} to 'person_crop_{i}.jpg'.")

[SUCCESS] Successfully read frame 110 from video ../v_0/input/3c.mp4.

0: 384x640 3 persons, 170.6ms
Speed: 7.6ms preprocess, 170.6ms inference, 11.6ms postprocess per image at shape (1, 3, 384, 640)
[SUCCESS] Detected 3 persons in the frame.
[INFO] Cropping person 0 with bounding box: [[     807.12      222.78      895.39      486.35]]
[SUCCESS] Saved cropped person 0 to 'person_crop_0.jpg'.
[INFO] Cropping person 1 with bounding box: [[     741.28      118.91      792.43      261.58]]
[SUCCESS] Saved cropped person 1 to 'person_crop_1.jpg'.
[INFO] Cropping person 2 with bounding box: [[     788.63         112      837.43      254.65]]
[SUCCESS] Saved cropped person 2 to 'person_crop_2.jpg'.
