In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights

import numpy as np
import pandas as pd
from fastparquet import write

import os
import glob as glob

import cv2
from PIL import Image


%matplotlib inline

In [2]:
class FinalLayer(nn.Module):
    """Modified last layer for resnet50 for your dataset"""
    def __init__(self):
        super(FinalLayer, self).__init__()
        self.fc = nn.Linear(2048, 12)  # Assuming you have 12 output classes
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc(x)
        out = self.sigmoid(out)
        return out

def modified_resnet50():
    # Load pretrained resnet50 with a modified last fully connected layer
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    model.fc = FinalLayer()
    return model

# Load the modified ResNet-50 model
model = modified_resnet50()

# Load the protest prediction model
model_checkpoint = torch.load('../../protest-detection-violence-estimation/model_best.pth.tar')
model.load_state_dict(model_checkpoint['state_dict'])
model.eval()

if torch.cuda.is_available():
    model = model.to('cuda')


In [3]:
# Define image transformations
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the model's input size
    transforms.ToTensor(),           # Convert to tensor
])

#### Accessing images

In [24]:
def protest_inference(path_to_img:str):
    # Load your input image
    image = Image.open(path_to_img)

    # Preprocess the image
    input_tensor = preprocess(image)

    # Add a batch dimension (1 image)
    input_tensor = input_tensor.unsqueeze(0)

    # Move all to cuda 
    if torch.cuda.is_available():
        input_tensor = input_tensor.to('cuda')

    # Perform inference
    with torch.no_grad():
        output = model(input_tensor)

    output = output.to('cpu')

    # Convert the output tensor to list
    output_list = output[0].tolist()
    
    return output_list

def run_detection_directory(path_to_res_file:str, country:str, test_data_dir: str):
    c = 0
    f = 0
    
    # File to store the results (parqur file)
    file_path = path_to_res_file+f'results_{country}.parquet'
    
    # Access the channel folder 
    channel_dirs = [filename for filename in os.listdir(test_data_dir) if not filename.startswith('.')]
    
    for channel_d in channel_dirs: # Get each channel 
        channel_name = channel_d
        
        sub_folder_path = os.path.join(test_data_dir, channel_d)
        
        # Access the video's frames folder 
        videos_dir = [filename for filename in os.listdir(sub_folder_path) if not filename.startswith('.')]
        
        for video_d in videos_dir:
            video_name = video_d
            
            video_path = os.path.join(sub_folder_path, video_d)
            
            frames = [f for f in os.listdir(video_path) if not f.startswith('.')] # Get the frames 
            
            for frame in frames:
                frame_path = os.path.join(video_path, frame)
                
                output_list = protest_inference(frame_path)
                
                # Extract country, channel name, video, and frame from the file path
                parts = frame_path.split('/')
                channel_name = parts[-3]
                video_name = parts[-2]
                frame_name = parts[-1]
                
                # Append the data to the list
                #data.append([channel_name, video_name, frame_name] + output_list)
            
                # Write the labels to a data frame
                aux = [channel_name, video_name, frame_name] + output_list
                data = [aux]
                df = pd.DataFrame(data, columns=['Channel', 'Video', 'Frame', "protest", "violence", "sign", "photo", "fire", "police", "children", "group_20", "group_100", "flag", "night", "shouting"])
                       
                # If the file does not exists, create it
                if not os.path.isfile(file_path): 
                    write(file_path, df)
                else: # Otherwise, write on it
                    write(file_path, df, append=True)
                
                f+=1
                if f>2:
                    break
            
            c +=1
            if c>1:
                break
                    
    
    return file_path

In [3]:
# Tree: sweden > channel > video > frames 

test_swe_path = '/zpool/beast-mirror/labour-movements-mobilisation-via-visual-means/youtube_video_frames/sweden/'
res_path = '/zpool/beast-mirror/labour-movements-mobilisation-via-visual-means/protest_derection_results/'
# Call the function and store the resulting DataFrame
#run_detection_directory(res_path, 'sweden', test_swe_path)

In [6]:
#data_frame

Unnamed: 0,Channel,Video,Frame,protest,violence,sign,photo,fire,police,children,group_20,group_100,flag,night,shouting
0,TCOSverige,vD9AD9Mkn00,vD9AD9Mkn00_1000.jpg,0.000225,0.082358,0.968820,0.052513,0.000126,0.000292,0.105370,0.021472,0.000812,0.002075,0.001571,0.001512
1,TCOSverige,vD9AD9Mkn00,vD9AD9Mkn00_500.jpg,0.011258,0.296230,0.772745,0.009518,0.019400,0.030011,0.005413,0.350340,0.054107,0.023425,0.008701,0.006307
2,TCOSverige,vD9AD9Mkn00,vD9AD9Mkn00_1250.jpg,0.002343,0.130016,0.936888,0.033643,0.000915,0.000817,0.002302,0.297446,0.039660,0.008272,0.017428,0.002822
3,TCOSverige,9K-XapzK6Ps,9K-XapzK6Ps_250.jpg,0.000323,0.165700,0.885031,0.021366,0.000516,0.001810,0.008540,0.753946,0.039261,0.007584,0.035628,0.002676
4,Boost2013,XVEAnHeWFW8,XVEAnHeWFW8_1250.jpg,0.004670,0.219886,0.954517,0.144653,0.002254,0.004844,0.011936,0.278515,0.012142,0.014497,0.008824,0.024927
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84,Lararforbundet,0_SSqMq2sq0,0_SSqMq2sq0_1200.jpg,0.000619,0.081016,0.991842,0.090726,0.000083,0.000180,0.066205,0.115498,0.005357,0.001675,0.005137,0.000989
85,lararforbundetkarlskrona590,w0KJNnGde8o,w0KJNnGde8o_2400.jpg,0.013474,0.184807,0.780389,0.074723,0.006902,0.007025,0.098665,0.202513,0.008689,0.022039,0.001305,0.009016
86,visionkanalen,YP___XatUsA,YP___XatUsA_250.jpg,0.000467,0.241174,0.893859,0.044116,0.013640,0.010297,0.012991,0.493185,0.077919,0.012357,0.029812,0.012524
87,Sverigesingenjorer,_C1IVrOa1Qg,_C1IVrOa1Qg_1500.jpg,0.001029,0.223190,0.724311,0.044494,0.001537,0.005074,0.001450,0.726003,0.074601,0.054703,0.007514,0.004565


In [4]:
pd.read_parquet(res_path+'results_sweden.parquet', engine='fastparquet')

Unnamed: 0,Channel,Video,Frame,protest,violence,sign,photo,fire,police,children,group_20,group_100,flag,night,shouting
0,TCOSverige,vD9AD9Mkn00,vD9AD9Mkn00_1000.jpg,0.000225,0.082358,0.968820,0.052513,0.000126,0.000292,0.105370,0.021472,0.000812,0.002075,0.001571,0.001512
1,TCOSverige,vD9AD9Mkn00,vD9AD9Mkn00_500.jpg,0.011258,0.296230,0.772745,0.009518,0.019400,0.030011,0.005413,0.350340,0.054107,0.023425,0.008701,0.006307
2,TCOSverige,vD9AD9Mkn00,vD9AD9Mkn00_1250.jpg,0.002343,0.130016,0.936888,0.033643,0.000915,0.000817,0.002302,0.297446,0.039660,0.008272,0.017428,0.002822
3,TCOSverige,9K-XapzK6Ps,9K-XapzK6Ps_250.jpg,0.000323,0.165700,0.885031,0.021366,0.000516,0.001810,0.008540,0.753946,0.039261,0.007584,0.035628,0.002676
4,Boost2013,XVEAnHeWFW8,XVEAnHeWFW8_1250.jpg,0.004670,0.219886,0.954517,0.144653,0.002254,0.004844,0.011936,0.278515,0.012142,0.014497,0.008824,0.024927
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,Lararforbundet,0_SSqMq2sq0,0_SSqMq2sq0_1200.jpg,0.000619,0.081016,0.991842,0.090726,0.000083,0.000180,0.066205,0.115498,0.005357,0.001675,0.005137,0.000989
261,lararforbundetkarlskrona590,w0KJNnGde8o,w0KJNnGde8o_2400.jpg,0.013474,0.184807,0.780389,0.074723,0.006902,0.007025,0.098665,0.202513,0.008689,0.022039,0.001305,0.009016
262,visionkanalen,YP___XatUsA,YP___XatUsA_250.jpg,0.000467,0.241174,0.893859,0.044116,0.013640,0.010297,0.012991,0.493185,0.077919,0.012357,0.029812,0.012524
263,Sverigesingenjorer,_C1IVrOa1Qg,_C1IVrOa1Qg_1500.jpg,0.001029,0.223190,0.724311,0.044494,0.001537,0.005074,0.001450,0.726003,0.074601,0.054703,0.007514,0.004565
