In [1]:
import pandas as pd
from potnet import *

In [2]:
high_model = load_model('hf_potnet_model_v3.pt')
mid_model = load_model('hf_potnet_mid.pt')

Model loaded from hf_potnet_model_v3.pt
Model loaded from hf_potnet_mid.pt


In [3]:
mid_rows = 271634 - 37247
high_rows = 271634 - 15166
only_high_data = high_model.generate(high_rows)
only_mid_data = mid_model.generate(mid_rows)

In [4]:
only_high_data.downloads_category.value_counts()

downloads_category
High    256468
Name: count, dtype: int64

In [5]:
only_mid_data.downloads_category.value_counts()

downloads_category
Mid    234387
Name: count, dtype: int64

In [6]:
df = pd.read_csv('data/hf_models_withmodelcard_nov2024.csv')
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1191759 entries, 0 to 1191758
Data columns (total 24 columns):
 #   Column               Non-Null Count    Dtype 
---  ------               --------------    ----- 
 0   model_id             1191759 non-null  object
 1   num_downloads        1191759 non-null  int64 
 2   num_likes            1191759 non-null  int64 
 3   is_private           1191759 non-null  bool  
 4   task                 1191759 non-null  object
 5   tags                 1191759 non-null  object
 6   author               1191759 non-null  object
 7   author_category      1191759 non-null  object
 8   base_model_relation  322 non-null      object
 9   base_model           269144 non-null   object
 10  language             1191759 non-null  object
 11  model_creator        6528 non-null     object
 12  model_type           4578 non-null     object
 13  model_name           6433 non-null     object
 14  model_card_tags      376770 non-null   object
 15  datasets       

In [7]:
import ast
df['tags'] = df['tags'].apply(ast.literal_eval)
df['location'] = df['tags'].apply(
    lambda tags: next((tag.split(':', 1)[1] for tag in tags if tag.startswith('region:')), None)
)

In [8]:
df = df[['task_group', 'author_category', 'language_category', 'downloads_category', 'location']]
df.downloads_category.value_counts()

downloads_category
Very Low    867712
Low         271634
Mid          37247
High         15166
Name: count, dtype: int64

In [9]:
combined_data = pd.concat([only_mid_data, only_high_data], ignore_index=True)
combined_data.downloads_category.value_counts()

downloads_category
High    256468
Mid     234387
Name: count, dtype: int64

In [10]:
combined_data_v2 = pd.concat([df, combined_data], ignore_index=True)
combined_data_v2.downloads_category.value_counts()

downloads_category
Very Low    867712
Low         271634
High        271634
Mid         271634
Name: count, dtype: int64

# Stratified sampling

Remove rows from Very Low (apply stratified sampling, and try to keep equal no. of samples for each task group where downloads category is very low)

In [11]:
# Determine the minimum count across downloads_category buckets
min_count = combined_data_v2['downloads_category'].value_counts().min()
print("Minimum count for downloads_category:", min_count)

def stratified_sampling_within_category(group, target):
    # Get unique task groups in this downloads_category
    unique_tasks = group['task_group'].unique()
    n_tasks = len(unique_tasks)
    # Determine target sample size per task_group (using integer division)
    sample_per_task = target // n_tasks
    sampled_frames = []
    for task in unique_tasks:
        sub_df = group[group['task_group'] == task]
        # If sub_df has fewer rows than sample_per_task, sample with replacement; otherwise without
        if len(sub_df) < sample_per_task:
            sampled = sub_df.sample(n=sample_per_task, replace=True, random_state=42)
        else:
            sampled = sub_df.sample(n=sample_per_task, random_state=42)
        sampled_frames.append(sampled)
    # Combine the samples from all task groups within this downloads_category
    return pd.concat(sampled_frames)

# Apply stratified sampling per downloads_category group
sampled_df = combined_data_v2.groupby('downloads_category', group_keys=False).apply(lambda g: stratified_sampling_within_category(g, min_count))

print("Sample counts by downloads_category:")
print(sampled_df['downloads_category'].value_counts())

print("Within each downloads_category, task_group counts:")
print(sampled_df.groupby('downloads_category')['task_group'].value_counts())

Minimum count for downloads_category: 271634


  sampled_df = combined_data_v2.groupby('downloads_category', group_keys=False).apply(lambda g: stratified_sampling_within_category(g, min_count))


Sample counts by downloads_category:
downloads_category
High        271632
Low         271629
Mid         271629
Very Low    271629
Name: count, dtype: int64
Within each downloads_category, task_group counts:
downloads_category  task_group                    
High                Audio Processing                  33954
                    Data Analysis & Classification    33954
                    Image Processing                  33954
                    Multimodal Processing             33954
                    Specialized Applications          33954
                    Text Processing                   33954
                    Unknown                           33954
                    Video Processing                  33954
Low                 Audio Processing                  30181
                    Data Analysis & Classification    30181
                    Image Processing                  30181
                    Multimodal Processing             30181
                    

# There are no rows where downloads category is high and task_group is 'Other' 
hence we have low

In [12]:
df[(df.task_group == 'Other') & (df.downloads_category == 'high')]

Unnamed: 0,task_group,author_category,language_category,downloads_category,location


In [13]:
df[(df.task_group == 'Other') & (df.downloads_category == 'Mid')]

Unnamed: 0,task_group,author_category,language_category,downloads_category,location
498655,Other,Gold,High,Mid,us
595554,Other,Gold,High,Mid,us


In [34]:
combined_data_v2.to_csv('data/generated_data/hf_11_24_generated.csv', index=False)
