In [1]:
# load model

import glob
import os
import cv2
import torch
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from timm import create_model
import torch.nn.functional as F
from skimage import io, transform
import matplotlib.pyplot as plt

# --- Input / output locations ----------------------------------------------
video_path = "/Users/annastuckert/Documents/GitHub/facemap/cam1_G7c1_1_10seconds.avi"
# NOTE(review): not referenced by the later cells shown here (they write to
# `video_output_path`); confirm whether it is still needed.
output_video_path = "output_video_with_keypoints_224x224.mp4"
model_path = "models/best_model.pt"

# --- Compute device: a CUDA GPU when one is available, otherwise the CPU ----
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define model loading function
def load_model(model_path, num_output_classes=24):
    """Build timm's single-channel ``vit_base_patch16_224`` and load trained weights.

    Args:
        model_path: Path of a checkpoint holding a ``state_dict`` for this network.
        num_output_classes: Number of model outputs (default 24 — presumably
            12 interleaved (x, y) keypoints; TODO confirm against training).

    Returns:
        The network, moved to the module-level ``device`` and set to eval mode.
    """
    net = create_model(
        "vit_base_patch16_224",
        pretrained=False,
        in_chans=1,
        num_classes=num_output_classes,
    )
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    # nn.Module.to() and .eval() both return the module itself
    return net.to(device).eval()

# Load model. On failure, report and re-raise rather than calling exit():
# inside a Jupyter kernel exit() ends the kernel session (losing all state),
# whereas an exception simply stops "Run All" at this cell with a traceback.
try:
    model = load_model(model_path)
except Exception as e:
    print(f"Error loading model: {e}")
    raise
else:
    print("Model loaded successfully.")


Model loaded successfully.


In [None]:
# import cv2
# import os

# # Define video path and output folder
# video_path = '/Users/annastuckert/Documents/GitHub/facemap/cam1_G7c1_1_10seconds.avi'  # Update with the path to your video
# output_folder = 'frames'  # Folder to save frames

# # Ensure the output folder exists
# os.makedirs(output_folder, exist_ok=True)

# # Set up video capture
# cap = cv2.VideoCapture(video_path)
# fps = int(cap.get(cv2.CAP_PROP_FPS))
# frame_count = 0

# # Process each frame in the video
# while cap.isOpened():
#     ret, frame = cap.read()
#     if not ret:
#         break
    
#     # Resize the frame to 224x224
#     #frame = cv2.resize(frame, (224, 224))

#     # Save each frame as an image in the output folder
#     frame_filename = os.path.join(output_folder, f"frame_{frame_count:04d}.jpg")
#     cv2.imwrite(frame_filename, frame)

#     frame_count += 1

# # Release resources
# cap.release()
# print(f"Frames saved to '{output_folder}' folder with 224x224 resolution.")


In [8]:
# load packages
import glob
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
from skimage import transform,io
import os
import pdb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Folder holding the frames extracted from the video
IMG_LOC = '/Users/annastuckert/Documents/GitHub/facemap/frames/'

# make folder called low_res if it does not exist
if os.path.isdir(IMG_LOC + 'low_res'):
    print('Folder exists!')
else:
    os.makedirs(IMG_LOC + 'low_res')

# find all JPEG frames, sorted by filename
img_files = sorted(glob.glob(IMG_LOC + '*.jpg'))
if not img_files:
    # Fail with an actionable message instead of an IndexError on img_files[0] below
    raise FileNotFoundError(
        f"No .jpg frames found in '{IMG_LOC}'. Extract frames from the video first."
    )
# find labels
#labels = pd.read_csv(IMG_LOC+'labels.csv')

# specify model input height and width
h = w = 224

# read one image to learn the original frame size
img = plt.imread(img_files[0])
h_org, w_org = img.shape[:2]

Folder exists!


In [14]:
import glob
import os
import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from skimage import io, transform, color

# Where results go: annotated frames -> low_res_folder, stitched video -> video_output_path
low_res_folder = '/Users/annastuckert/Documents/GitHub/facemap/frames/low_res'
video_output_path = '/Users/annastuckert/Documents/GitHub/facemap/output_video.avi'  # Adjust the output path
os.makedirs(low_res_folder, exist_ok=True)

# Model-input preprocessing: ndarray -> PIL image -> one grayscale channel -> float tensor.
# NOTE(review): ToPILImage rejects float64 ndarrays; feed it uint8 images.
_to_model_input = [
    transforms.ToPILImage(),                      # ndarray -> PIL image
    transforms.Grayscale(num_output_channels=1),  # force a single channel
    transforms.ToTensor(),                        # PIL -> (C, H, W) float tensor
]
vit_transform = transforms.Compose(_to_model_input)

# Square model input resolution
target_width = 224
target_height = target_width

# `img_files` (sorted *.jpg frame paths) is built in the previous cell.

# Use a CUDA GPU when one is available, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# `model` must already exist: it is created with load_model(model_path) in the first cell.

# Annotated uint8 frames collected for the output video
frames = list()

# Process every extracted frame: centre-crop to a square, convert to grayscale,
# downsample to the model input size, predict keypoints and draw them.
for img_file in img_files:  # slice img_files (e.g. img_files[:10]) to process a subset
    # Read the image
    img = io.imread(img_file)

    # Get original dimensions
    h_org, w_org = img.shape[:2]

    # Centre-crop the largest square. Using the shorter side keeps the crop valid
    # for portrait frames too (a width-minus-height offset would go negative there).
    side = min(h_org, w_org)
    start_x = (w_org - side) // 2
    start_y = (h_org - side) // 2
    cropped_img = img[start_y:start_y + side, start_x:start_x + side]

    # Convert to grayscale; single-channel frames are used as-is and an alpha
    # channel, if any, is dropped before rgb2gray
    if cropped_img.ndim == 3:
        gray_img = color.rgb2gray(cropped_img[..., :3])
    else:
        gray_img = cropped_img

    # Resize to the model input size; transform.resize returns floats in [0, 1]
    resized_img = transform.resize(gray_img, (target_height, target_width), anti_aliasing=True)

    # 3-channel (RGB) copy of the grayscale frame so keypoints can be drawn in colour
    color_img = np.stack([resized_img] * 3, axis=-1)

    # ToPILImage raises TypeError for float64 arrays, so hand vit_transform an
    # 8-bit image; ToTensor maps it back to a float tensor in [0, 1].
    frame_u8 = np.clip(resized_img * 255, 0, 255).astype(np.uint8)
    input_frame = vit_transform(frame_u8).unsqueeze(0).to(device)  # add batch dimension

    # Model inference for keypoints
    with torch.no_grad():
        scores = F.softplus(model(input_frame))
        keypoints = scores.squeeze().cpu().numpy()

    # Overlay keypoints, read as interleaved (x, y) pairs in resized-image pixels
    # (NOTE(review): confirm this layout against the training labels). The
    # `len - 1` bound skips a dangling coordinate if the output length is odd.
    for i in range(0, len(keypoints) - 1, 2):
        x, y = int(keypoints[i]), int(keypoints[i + 1])
        # (1, 0, 0) is red in this RGB image with values in [0, 1]
        cv2.circle(color_img, (x, y), radius=5, color=(1, 0, 0), thickness=-1)

    # Convert once to 8-bit RGB for both the saved image and the video frame list
    annotated_u8 = (color_img * 255).astype(np.uint8)

    base_filename = os.path.basename(img_file)
    save_path = os.path.join(low_res_folder, base_filename)
    io.imsave(save_path, annotated_u8)

    frames.append(annotated_u8)

# Create a video from the annotated frames
video_fps = 30  # NOTE(review): fixed rate; the source video's real FPS is not read here
if frames:
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'XVID')  # You can change the codec if necessary
    video_writer = cv2.VideoWriter(video_output_path, fourcc, video_fps, (width, height))
    if not video_writer.isOpened():
        # Otherwise write() silently does nothing and the "saved" message below lies
        raise RuntimeError(f"Could not open video writer for '{video_output_path}'.")
    try:
        for frame in frames:
            # Frames are RGB (skimage convention) but OpenCV expects BGR; without
            # this conversion the red keypoints would come out blue.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        video_writer.release()  # Finalize the video file
    print(f"Video saved to '{video_output_path}'.")

print(f"Processed images saved to '{low_res_folder}'.")
