In [1]:
import random
import pycountry
import pandas as pd
from functools import lru_cache
from PIL import Image, UnidentifiedImageError

import os
import gzip
import shutil
import asyncio
import aiohttp
import nest_asyncio
from io import BytesIO
from tqdm.asyncio import tqdm_asyncio
from urllib.parse import quote, urlsplit, urlunsplit

In [3]:
file_path = 'links.tsv.gz'

with gzip.open(file_path, 'rt', encoding='utf-8') as file:
    df = pd.read_csv(file, sep='\t')

df

Unnamed: 0,id,label,nationality,date_of_birth,year_of_birth,image
0,Q261019,Isabelle Boulay,canada,1972-07-06T00:00:00Z,1972,https://upload.wikimedia.org/wikipedia/commons...
1,Q261049,Florencia Pe√±a,argentina,1974-11-07T00:00:00Z,1974,https://upload.wikimedia.org/wikipedia/commons...
2,Q261064,Colin Fleming,united_kingdom,1984-08-13T00:00:00Z,1984,https://upload.wikimedia.org/wikipedia/commons...
3,Q261069,Stig Inge Bj√∏rnebye,norway,1969-12-11T00:00:00Z,1969,https://upload.wikimedia.org/wikipedia/commons...
4,Q261123,Alireza Haghi,iran,1979-02-08T00:00:00Z,1979,https://upload.wikimedia.org/wikipedia/commons...
...,...,...,...,...,...,...
311596,Q61045974,Genc Sermaxhaj,kosovo,1988-12-28T00:00:00Z,1988,https://upload.wikimedia.org/wikipedia/commons...
311597,Q61093547,Alfonso M√©ndiz,spain,1961-07-21T00:00:00Z,1961,https://upload.wikimedia.org/wikipedia/commons...
311598,Q61163004,Carmen A√≠da Lazo,el_salvador,1976-01-03T00:00:00Z,1976,https://upload.wikimedia.org/wikipedia/commons...
311599,Q61315007,Christin Walker,united_states_of_america,1988-05-03T00:00:00Z,1988,https://upload.wikimedia.org/wikipedia/commons...


In [4]:
# –Ω–∞—á–Ω–µ–º —Å —É–º–Ω–æ–≥–æ –ø–æ–¥—Ö–æ–¥–∞: –ø–æ–∏—â–µ–º —Å–æ–≤–ø–∞–¥–µ–Ω–∏—è –≤ –±–∏–±–ª–∏–æ—Ç–µ–∫–µ pycountry
@lru_cache(maxsize=None)
def map_to_country(name):
    try:
        country = pycountry.countries.search_fuzzy(name)[0]
        return country.name
    except LookupError:
        return None

df['country'] = df['nationality'].apply(map_to_country)
manual_mapping = {
    'united_kingdom': 'United Kingdom',
    'kingdom_of_the_netherlands': 'Netherlands',
    'united_states_of_america': 'United States',
    'new_zealand': 'New Zealand',
    'socialist_federal_republic_of_yugoslavia': 'Yugoslavia',
    'czech_republic': 'Czech Republic',
    'south_africa': 'South Africa',
    "people's_republic_of_china": 'China',
    'south_korea': 'Korea', ###
    'soviet_union': 'Russian Federation',
    'republic_of_china_(1912‚Äì1949)': 'China',
    'turkey': 'Turkey',
    'bosnia_and_herzegovina': 'Bosnia and Herzegovina',
    'democratic_republic_of_the_congo': 'Democratic Republic of the Congo',
    'western_sahara': 'Western Sahara',
    'republic_of_macedonia': 'North Macedonia',
    'yugoslavia': 'Yugoslavia',
    'trinidad_and_tobago': 'Trinidad and Tobago',
    'ivory_coast': 'C√¥te d\'Ivoire',
    'greenland':'Canada',
    'puerto_rico': 'Puerto Rico',
    'northern_ireland': 'United Kingdom',
    'saudi_arabia': 'Saudi Arabia',
    'republic_of_the_congo': 'Congo',
    'palau':'Philippines', ###
    'montserrat':'Cuba', ###
    'north_korea': 'Korea', ###
    'antigua_and_barbuda': 'Antigua and Barbuda',
    'liechtenstein':'Israel', ###
    'czechoslovakia': 'Czech Republic',
    'sri_lanka': 'Sri Lanka',
    'san_marino': 'San Marino',
    'dominican_republic': 'Dominican Republic',
    'state_of_palestine': 'Palestine, State of',
    'burkina_faso': 'Burkina Faso',
    'federated_states_of_micronesia': 'Micronesia, Federated States of',
    'second_polish_republic': 'Poland',
    'costa_rica': 'Costa Rica',
    'british_raj': 'India',
    'german_democratic_republic': 'Germany',
    'united_arab_emirates': 'United Arab Emirates',
    'faroe_islands': 'United kingdom', ###
    'saint_kitts_and_nevis': 'Saint Kitts and Nevis',
    'hong_kong': 'China', ###
    "people's_republic_of_poland": 'Poland',
    'serbia_and_montenegro': 'Serbia',
    'nazi_germany': 'Germany',
    'el_salvador': 'El Salvador',
    'central_african_republic': 'Central African Republic',
    'kingdom_of_yugoslavia': 'Yugoslavia',
    'weimar_republic': 'Germany',
    'kingdom_of_denmark': 'Denmark',
    'saint_vincent_and_the_grenadines': 'Saint Vincent and the Grenadines',
    'papua_new_guinea': 'Papua New Guinea',
    'cape_verde': 'Cabo Verde',
    'palestinian_national_authority': 'Palestine, State of',
    'empire_of_japan': 'Japan',
    'sierra_leone': 'Sierra Leone',
    'east_timor': 'Indonesia', ###
    'russian_soviet_federative_socialist_republic': 'Russian Federation',
    'kingdom_of_serbs,_croatians_and_slovenes': 'Yugoslavia',
    'tibet': 'China',
    'mandatory_palestine': 'Palestine, State of',
    'kingdom_of_italy': 'Italy',
    'turkish_republic_of_northern_cyprus': 'Cyprus',
    'kingdom_of_romania': 'Romania',
    'guernsey':'United Kingdom', ###
    'equatorial_guinea': 'Equatorial Guinea',
    'transnistria': 'Romania',
    'great_britain': 'United Kingdom',
    'kingdom_of_iraq': 'Iraq',
    'south_sudan': 'Sudan',
    'ukrainian_soviet_socialist_republic': 'Ukraine',
    's√£o_tom√©_and_pr√≠ncipe': 'S√£o Tom√© and Pr√≠ncipe',
    'artsakh': 'Azerbaijan',
    'federal_republic_of_yugoslavia': 'Yugoslavia',
    'armenian_soviet_socialist_republic': 'Armenia',
    'kingdom_of_egypt': 'Egypt',
    'francoist_spain': 'Spain',
    'protectorate_of_bohemia_and_moravia': 'Czech Republic',
    'west_germany': 'Germany',
    'solomon_islands': 'Solomon Islands',
    'saint_lucia': 'Saint Lucia',
    'colonial_nigeria': 'Nigeria',
    'kingdom_of_hungary': 'Hungary',
    "people's_republic_of_hungary": 'Hungary',
    'south_vietnam': 'Vietnam',
    'isle_of_man': 'United Kingdom', ###
    'manchukuo': 'China',
    'laos': 'Laos',
    'czechoslovak_socialist_republic': 'Czech Republic',
    'british_hong_kong': 'China',
    'slovak_state_(1939-1945)': 'Slovakia',
    'kingdom_of_bulgaria': 'Bulgaria',
    "people's_republic_of_bulgaria": 'Bulgaria',
    'sultanate_of_zanzibar': 'Tanzania, United Republic of', ###
    'dutch_east_indies': 'Indonesia',
    'french_algeria': 'Algeria',
    'marshall_islands': 'Marshall Islands',
    'byelorussian_soviet_socialist_republic': 'Belarus',
    'japanese_people': 'Japan',
    'welsh_people': 'United Kingdom',
    'british_people': 'United Kingdom',
    'rhodesia': 'Zimbabwe',
    'hungarian': 'Hungary',
    'federation_of_rhodesia_and_nyasaland': 'Zimbabwe',
    'socialist_republic_of_romania': 'Romania',
    'kingdom_of_albania': 'Albania',
    'iraqi_kurdistan': 'Iraq',
    'union_of_south_africa': 'South Africa',
    'indian_people': 'India',
    'cook_islands': 'New Zealand', ###
    'niue': 'New Zealand', ###
    'georgian_soviet_socialist_republic': 'Georgia',
    'southern_rhodesia': 'Zimbabwe',
    'british_virgin_islands': 'United States',
    'american_samoa': 'United States',
    'danish': 'Denmark',
    'kingdom_of_afghanistan': 'Afghanistan',
    'first_republic_of_austria': 'Austria',
    'british_empire': 'United Kingdom',
    'kingdom_of_greece': 'Greece',
    'belgian_congo': 'Democratic Republic of the Congo',
    'macau': 'China',
    "people's_socialist_republic_of_albania": 'Albania',
    'yemen_arab_republic': 'Yemen',
    'vatican_city': 'Italy',
    'kenya_colony': 'Kenya',
    'tibet_from_1912_to_1951': 'China',
    'ruanda-urundi': 'Rwanda',
    'german_empire': 'Germany',
    'nepali': 'Nepal',
    'united_kingdom_of_great_britain_and_ireland': 'United Kingdom',
    'tuva_republic': 'Russian Federation',
    'austrians': 'Austria',
    'british_national_(overseas)': 'United Kingdom',
    'filipino_people': 'Philippines',
    'lithuanian_soviet_socialist_republic': 'Lithuania',
    'country_of_the_kingdom_of_the_netherlands': 'Netherlands',
    'netherlands_antilles': 'Netherlands',
    'republic_of_upper_volta': 'Burkina Faso',
    'first_portuguese_republic': 'Portugal',
    "romanian_people's_republic": 'Romania',
    "mongolian_people's_republic": 'Mongolia',
    'democratic_republic_of_georgia': 'Georgia',
    'azerbaijani': 'Azerbaijan',
    'bangladeshis': 'Bangladesh',
    'bulgarian': 'Bulgaria',
    'italians': 'Italy',
    'american_occupation_zone': 'Germany',
    'republic_of_cuba_(1902‚Äì59)': 'Cuba',
    'south_yemen': 'Yemen',
    'irish_republic': 'Ireland',
    'british_somaliland': 'Somalia',
    'chinese_taipei': 'China',
    'bosniaks': 'Bosnia and Herzegovina',
    'tibetan_people': 'China',
    'kingdom_of_mysore': 'India',
    'beiyang_government': 'China',
    'afrika': 'South Africa',
    'americans': 'United States',
    'chileans': 'Chile',
    'sint_maarten': 'Netherlands',
    'hungarians': 'Hungary',
    'norwegian': 'Norway',
    'irish': 'Ireland',
    'czechoslovak_republic': 'Czech Republic',
    'mexicana': 'Mexico',
    'cayman_islands': 'Cuba',
    's√£o_paulo': 'Brazil',
    'qu√©bec-comt√©': 'Canada',
    'israelis': 'Israel',
    'range_of_andia': 'Spain',
    'anguilla':'Cuba',
    'mar√≠timo': 'Portugal',
    'chilena': 'Chile',
    'canadian_french': 'Canada',
    'egyptians': 'Egypt',
    'francia': 'France',
    'ukrainians': 'Ukraine',
    'dominicana': 'Dominican Republic',
    'kurdistan': 'Turkey',
    'germans': 'Germany',
    'the_republic_of_abkhazia': 'Georgia', #### —É–ø—Å...
    'united_federation_of_planets': 'United States', # –ª–µ–≥–µ–Ω–¥–∞
    'katun': 'Russian Federation',
    'siciliana': 'Italy',
    'sovi√®tic': 'Russian Federation',
    'first_hungarian_republic': 'Hungary',
    'staffanstorp_municipality': 'Sweden',
    'nuu-chah-nulth': 'Canada',
    'croacia': 'Croatia',
    'liberland': 'Czech Republic',
    'spain_under_the_restoration': 'Spain',
    'venezolano.': 'Venezuela, Bolivarian Republic of',
    'estado_novo': 'Portugal',
    'ivanteyevskaya_street': 'Russian Federation',
    'kuwait_city': 'Kuwait',
    'florence': 'Italy',
    'monterrey': 'Mexico',
    'moldova':'Romania', ###
    'colombiana': 'Colombia',
    'ss_france': 'France',
    'francais_objective_specifique': 'France',
    'mexico_city': 'Mexico',
    'morocco_pavilion': 'Morocco',
    'brazil‚Äìuruguay_relations': 'Brazil',
    'ecuador_national_football_team': 'Ecuador',
    'langnau_am_albis': 'Switzerland',
    "federal_people's_republic_of_yugoslavia": 'Yugoslavia',
    'third_czechoslovak_republic': 'Czech Republic',
    'plastin': 'Romania',
    'nazareth': 'Israel',
    'korea':'Korea', ###
    'suisse_romande': 'Switzerland',
    'republika_srpska': 'Bosnia and Herzegovina',
    'san_luis_potos√≠': 'Mexico',
    'rep√∫blica_de_s√≠ria': 'Syrian Arab Republic',
    'tamil_eelam': 'Sri Lanka',
    'sockel_fm2+': 'Spain', # –Ω–µ–º–Ω–æ–≥–æ –Ω–µ–ø–æ–Ω—è—Ç–Ω–æ –ø—Ä–∏ —á–µ–º —Ç—É—Ç —Ä–æ–∑–µ—Ç–∫–∞...
    'canadian_nationality_law': 'Canada',
    'bicycle_kick': 'Chile',
    'santo_domingo': 'Dominican Republic',
    'qu√©b√©cois': 'Canada',
}
df['country'] = df['nationality'].map(manual_mapping).combine_first(df['country'])
df.loc[df["label"] == "Eliana Rubashkyn", "country"] = "Colombia"
df.loc[df["label"] == "Glen L Roberts", "country"] = "United States"
df.loc[df["label"] == "Denis P√©cic", "country"] = "France"
df = df.dropna(subset=['country'])

In [9]:
# —Å–ª–æ–≤–∞—Ä—å –¥–ª—è –≥—Ä—É–ø–ø–∏—Ä–æ–≤–∫–∏ —Ä–µ–¥–∫–∏—Ö —Å—Ç—Ä–∞–Ω
country_groups = {
    'Caribbean': [
        'Antigua and Barbuda', 'Bahamas', 'Barbados', 
        'Grenada', 'Saint Kitts and Nevis', 'Saint Lucia',
        'Saint Vincent and the Grenadines', 'Aruba', 'Bermuda',
        'Belize', 'Guyana'
    ],
    'Pacific Islands': [
        'Kiribati', 'Marshall Islands', 'Micronesia, Federated States of',
        'Nauru', 'Guam', 'Solomon Islands', 'Tuvalu', 'Vanuatu',
        'Papua New Guinea', 'Samoa'
    ],
    'African Small States': [
        'Botswana', 'Burundi', 'Cabo Verde', 'Comoros',
        'Djibouti', 'Equatorial Guinea', 'Eswatini', 'Gabon', 'Gambia', 'Eritrea',
        'Guinea-Bissau', 'Lesotho', 'Malawi', 'Mauritania', 'Mauritius', 'Mozambique',
        'S√£o Tom√© and Pr√≠ncipe', 'Seychelles', 'Togo', 'Madagascar',
        'Central African Republic', 'Chad', 'Sierra Leone', 'Liberia'                               
    ],
    'Central Asia': [
        'Kyrgyzstan', 'Tajikistan', 'Turkmenistan', 'Uzbekistan'
    ],
    'Middle East Small States': [
        'Brunei Darussalam', 'Oman', 'Qatar', 'Western Sahara', 'Yemen'
    ],
    'Other Europe': [
        'Gibraltar', 'Monaco', 'San Marino'
    ],
    'Other Asia': [
        'Bhutan', 'Maldives', 'Laos', 'Vietnam'
    ]
}

country_counts = df['country'].value_counts()
keep_individual = country_counts[country_counts > 49].index.tolist() #–Ω–∞–∑–≤–∞–Ω–∏—è —Å—Ç—Ä–∞–Ω –≥–¥–µ 50 –∏ –±–æ–ª—å—à–µ —Å—Ç—Ä–æ–∫ –Ω–µ –º–µ–Ω—è–µ–º
def group_country(country):
    if country in keep_individual:
        return country
    for group, countries in country_groups.items():
        if country in countries:
            return group
    return 'Other'  #–Ω–∞ –≤—Å—è–∫–∏–π —Å–ª—É—á–∞–π

df = df.copy()
df.loc[:, 'country'] = df['country'].apply(group_country)

In [12]:
dataset_root = "dataset_masha"
SEM_LIMIT = 40

def safe_url(url):
    try:
        parts = urlsplit(url)
        safe_path = quote(parts.path)
        return urlunsplit((parts.scheme, parts.netloc, safe_path, parts.query, parts.fragment))
    except:
        return url

async def is_url_accessible(session, url, semaphore):
    async with semaphore:
        try:
            url = safe_url(url)
            async with session.get(url, timeout=5) as resp:
                return resp.status == 200
        except:
            return False

async def download_image(session, row, index, semaphore, failed_rows):
    url = safe_url(row['image'])
    class_name = str(row['country']).strip().lower().replace(" ", "_")
    class_dir = os.path.join(dataset_root, class_name)
    os.makedirs(class_dir, exist_ok=True)

    async with semaphore:
        await asyncio.sleep(random.uniform(0.1, 0.3))

        try:
            async with session.get(url, timeout=5) as resp:
                if resp.status == 200:
                    content = await resp.read()
                    try:
                        image = Image.open(BytesIO(content)).convert('RGB')
                        image.save(os.path.join(class_dir, f"{index}.jpg"))
                    except UnidentifiedImageError:
                        failed_rows.append(row)
                        print(f"[{index}] –ù–µ—Ä–∞—Å–ø–æ–∑–Ω–∞–Ω–Ω—ã–π —Ñ–æ—Ä–º–∞—Ç –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è")
                else:
                    failed_rows.append(row)
                    print(f"[{index}] HTTP —Å—Ç–∞—Ç—É—Å: {resp.status}")
        except Exception as e:
            failed_rows.append(row)
            print(f"[{index}] –û—à–∏–±–∫–∞ —Å–∫–∞—á–∏–≤–∞–Ω–∏—è: {e}")

async def process_all(df):
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
        'Referer': 'https://google.com'
    }

    semaphore = asyncio.Semaphore(SEM_LIMIT)
    failed_rows = []

    async with aiohttp.ClientSession(headers=headers) as session:
        tasks_check = [
            is_url_accessible(session, url, semaphore) if pd.notna(url) else False
            for url in df['image']
        ]
        valid_mask = await tqdm_asyncio.gather(*tasks_check, position=0, desc="–ü—Ä–æ–≤–µ—Ä–∫–∞ URL")
        valid_mask = pd.Series(valid_mask, index=df.index)

        df_valid = df[valid_mask].reset_index(drop=True)
        df_failed_check = df[~valid_mask].reset_index(drop=True)


        tasks_download = [
            download_image(session, row, row["id"], semaphore, failed_rows)
            for _, row in df_valid.iterrows()
        ]
        await tqdm_asyncio.gather(*tasks_download, position=0, desc="–ó–∞–≥—Ä—É–∑–∫–∞ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π")

    df_failed_download = pd.DataFrame(failed_rows)
    df_failed_total = pd.concat([df_failed_check, df_failed_download], ignore_index=True)
    return df_valid, df_failed_total

async def process_all_in_batches(df, batch_size=10000, sleep_between_batches=30):
    total_batches = (len(df) + batch_size - 1) // batch_size
    all_valid, all_failed = [], []

    for i in range(total_batches):
        start, end = i * batch_size, min((i + 1) * batch_size, len(df))
        df_batch = df.iloc[start:end]
        print(f"\nüîπ –û–±—Ä–∞–±–æ—Ç–∫–∞ –±–∞—Ç—á–∞ {i+1}/{total_batches} ({start}‚Äì{end})")

        df_valid, df_failed = await process_all(df_batch)
        all_valid.append(df_valid)
        all_failed.append(df_failed)

        print(f"‚úÖ –ë–∞—Ç—á {i+1} –∑–∞–≤–µ—Ä—à—ë–Ω, –ø–∞—É–∑–∞ {sleep_between_batches} —Å–µ–∫...\n")
        await asyncio.sleep(sleep_between_batches)

    df_valid_final = pd.concat(all_valid, ignore_index=True)
    df_failed_final = pd.concat(all_failed, ignore_index=True)
    return df_valid_final, df_failed_final

In [13]:
nest_asyncio.apply()

In [20]:
# df_valid_1, df_fail_1 = await process_all_in_batches(df[:100000], batch_size=1000, sleep_between_batches=10)
# df_valid_1.to_csv("df_valid_1.csv", index=True)
# df_fail_1.to_csv("df_fail_1.csv", index=True)

In [18]:
# df_valid_2, df_fail_2 = await process_all_in_batches(df[100000:200000], batch_size=1000, sleep_between_batches=10)
# df_valid_2.to_csv("df_valid_2.csv", index=True)
# df_fail_2.to_csv("df_fail_2.csv", index=True)

In [16]:
# df_valid_3, df_fail_3 = await process_all_in_batches(df[200000:], batch_size=1000, sleep_between_batches=10)
# df_valid_3.to_csv("df_valid_3.csv", index=True)
# df_fail_3.to_csv("df_fail_3.csv", index=True)

In [2]:
df_valid_1 = pd.read_csv("df_valid_1.csv")
df_valid_2 = pd.read_csv("df_valid_2.csv")
df_valid_3 = pd.read_csv("df_valid_3.csv")
df_fail_1 = pd.read_csv("df_fail_1.csv")
df_fail_2 = pd.read_csv("df_fail_2.csv")
df_fail_3 = pd.read_csv("df_fail_3.csv")

In [3]:
df_valid = pd.concat([df_valid_1, df_valid_2, df_valid_3], ignore_index=True)
df_failed = pd.concat([df_fail_1, df_fail_2, df_fail_3], ignore_index=True)

In [12]:
df_failed["image"]

0        https://upload.wikimedia.org/wikipedia/commons...
1        https://upload.wikimedia.org/wikipedia/commons...
2        https://upload.wikimedia.org/wikipedia/commons...
3        https://upload.wikimedia.org/wikipedia/commons...
4        https://upload.wikimedia.org/wikipedia/commons...
                               ...                        
12837    https://upload.wikimedia.org/wikipedia/commons...
12838    https://upload.wikimedia.org/wikipedia/commons...
12839    https://upload.wikimedia.org/wikipedia/commons...
12840    https://upload.wikimedia.org/wikipedia/commons...
12841    https://upload.wikimedia.org/wikipedia/commons...
Name: image, Length: 12842, dtype: object

In [None]:
# shutil.make_archive("dataset_masha", 'zip', "dataset_masha")

In [17]:
# import os
# from sklearn.model_selection import train_test_split
# from torchvision.datasets import ImageFolder
# from torch.utils.data import Subset
#
# dataset = ImageFolder(root="dataset_masha")
#
# file_paths = [sample[0] for sample in dataset.samples]
# labels = [sample[1] for sample in dataset.samples]
#
#
# # –°–Ω–∞—á–∞–ª–∞ —Ä–∞–∑–¥–µ–ª—è–µ–º –Ω–∞ train+val (80%) –∏ test (20%)
# train_val_files, test_files, train_val_labels, test_labels = train_test_split(
#     file_paths,
#     labels,
#     test_size=0.2,
#     stratify=labels,
#     random_state=30
# )
# # –†–∞–∑–¥–µ–ª—è–µ–º train_val –Ω–∞ train –∏ val
# train_files, val_files, train_labels, val_labels = train_test_split(
#     train_val_files,
#     train_val_labels,
#     test_size=0.2,
#     stratify=train_val_labels,
#     random_state=30
# )
#
# # –ó–∞–≥—Ä—É–∂–∞–µ–º –ø–æ–ª–Ω—ã–π –¥–∞—Ç–∞—Å–µ—Ç —Å –ø—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏—è–º–∏
# full_dataset = ImageFolder(root="dataset_masha", transform=None)
#
# # –°–æ–∑–¥–∞–µ–º —Å–ª–æ–≤–∞—Ä—å –¥–ª—è –∏–Ω–¥–µ–∫—Å–æ–≤ (—á—Ç–æ–±—ã —Å–æ–ø–æ—Å—Ç–∞–≤–∏—Ç—å –ø—É—Ç–∏ —Å –∏–Ω–¥–µ–∫—Å–∞–º–∏ –≤ full_dataset)
# file_to_index = {os.path.normpath(path): idx for idx, (path, _) in enumerate(full_dataset.samples)}
#
# # –ü–æ–ª—É—á–∞–µ–º –∏–Ω–¥–µ–∫—Å—ã –¥–ª—è train, val, test
# train_indices = [file_to_index[os.path.normpath(path)] for path in train_files]
# val_indices = [file_to_index[os.path.normpath(path)] for path in val_files]
# test_indices = [file_to_index[os.path.normpath(path)] for path in test_files]
#
# # –°–æ–∑–¥–∞–µ–º Subset –¥–ª—è –∫–∞–∂–¥–æ–π —á–∞—Å—Ç–∏
# train_dataset = Subset(full_dataset, train_indices)
# val_dataset = Subset(full_dataset, val_indices)
# test_dataset = Subset(full_dataset, test_indices)
#
# # –°–æ–∑–¥–∞–µ–º –ø–æ–¥–¥–∏—Ä–µ–∫—Ç–æ—Ä–∏–∏ (–µ—Å–ª–∏ –Ω—É–∂–Ω–æ —Å–æ—Ö—Ä–∞–Ω–∏—Ç—å —Ä–∞–∑–¥–µ–ª–µ–Ω–Ω—ã–µ –¥–∞–Ω–Ω—ã–µ)
# os.makedirs("split_dataset_masha/train", exist_ok=True)
# os.makedirs("split_dataset_masha/val", exist_ok=True)
# os.makedirs("split_dataset_masha/test", exist_ok=True)
#
# # –ö–æ–ø–∏—Ä—É–µ–º —Ñ–∞–π–ª—ã –≤ —Å–æ–æ—Ç–≤–µ—Ç—Å—Ç–≤—É—é—â–∏–µ –ø–∞–ø–∫–∏ (–º–æ–∂–Ω–æ –∏—Å–ø–æ–ª—å–∑–æ–≤–∞—Ç—å shutil)
# import shutil
#
# def copy_files(files, target_dir):
#     for file in files:
#         class_name = os.path.basename(os.path.dirname(file))
#         dest_dir = os.path.join(target_dir, class_name)
#         os.makedirs(dest_dir, exist_ok=True)
#         shutil.copy(file, dest_dir)
#
# copy_files(train_files, "split_dataset_masha/train")
# copy_files(val_files, "split_dataset_masha/val")
# copy_files(test_files, "split_dataset_masha/test")
#
# # –¢–µ–ø–µ—Ä—å –º–æ–∂–Ω–æ –∑–∞–≥—Ä—É–∑–∏—Ç—å –∏—Ö —á–µ—Ä–µ–∑ ImageFolder
# train_dataset = ImageFolder("split_dataset_masha/train", transform=None)
# val_dataset = ImageFolder("split_dataset_masha/val", transform=None)
# test_dataset = ImageFolder("split_dataset_masha/test", transform=None)

Dataset ImageFolder
    Number of datapoints: 298771
    Root location: dataset_masha

In [None]:
# import numpy as np
#
# def print_class_distribution(dataset, name):
#     if isinstance(dataset, Subset):
#         labels = [dataset.dataset.targets[i] for i in dataset.indices]
#     else:
#         labels = dataset.targets
#     unique, counts = np.unique(labels, return_counts=True)
#     print(f"{name} distribution:")
#     for cls, count in zip(unique, counts):
#         print(f"Class {cls}: {count} samples ({count / len(labels):.2%})")
#
# print_class_distribution(train_dataset, "Train")
# print_class_distribution(val_dataset, "Validation")
# print_class_distribution(test_dataset, "Test")