In [None]:
'''
*Important!*
If you are running this inference code on Colab, make sure the files are at the following paths:
The training dataset is at /content/drive/MyDrive/data/train
The test dataset is at /content/drive/MyDrive/data/test
The trained model weights are at /content/drive/MyDrive/trained_model_fin.pth
The notebook runs correctly once these paths are right. The paths are determined by where you upload the files in Google Drive: if you upload the `data` folder and the weight file directly to the root of "My Drive", the paths above will match.
'''

In [3]:
!pip install torch
!pip install scikit-learn
!pip install pandas
!pip install tqdm
!pip install torchvision
!pip install numpy



In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import numpy as np
from sklearn.model_selection import train_test_split
import random

# Make Google Drive visible to the notebook; the dataset and the trained
# weights are read from paths under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Seed every random number generator the notebook relies on (Python's
# `random`, NumPy's legacy global RNG, and PyTorch CPU/CUDA) with one value.
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Both pipelines resize to 224x224 and finish with ToTensor (values in [0, 1]).
# No mean/std normalisation is applied. NOTE(review): presumably this matches
# how the checkpoint was trained; confirm before adding normalisation.
IMAGE_SIZE = (224, 224)

# Deterministic preprocessing for evaluation / inference.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Training preprocessing: the same resize plus light random augmentation.
train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(10),
]
data_transform = transforms.Compose(
    [transforms.Resize(IMAGE_SIZE), *train_augmentations, transforms.ToTensor()]
)

class BirdDataset(Dataset):
    """Image dataset over a folder of bird pictures.

    Two on-disk layouts are supported:

    * ``is_test=False``: ``folder_path/<class_name>/<image>``. Every
      sub-directory is one class. Classes are indexed in sorted (alphabetical)
      order, so the label mapping does not depend on the filesystem's
      ``os.listdir`` order. Non-directory entries at the top level are ignored
      and do not consume a class index.
    * ``is_test=True``: ``folder_path/<image>``, a flat folder of unlabelled
      images. Only regular files are used, in sorted order.

    Args:
        folder_path: Root folder of the dataset.
        transform: Optional callable applied to each loaded RGB PIL image.
        is_test: Selects the flat, unlabelled layout described above.

    Items returned by ``__getitem__``:
        * test mode:  ``(image, image_file_name)``
        * train mode: ``(image, label_index, dataset_index)``
    """

    def __init__(self, folder_path, transform=None, is_test=False):
        self.folder_path = folder_path
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = {}
        self.is_test = is_test

        self._load_data()

    @staticmethod
    def _list_files(folder):
        """Return sorted full paths of the regular files directly inside ``folder``."""
        # Skips sub-directories such as `.ipynb_checkpoints`, which PIL cannot open.
        return [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if os.path.isfile(os.path.join(folder, name))
        ]

    def _load_data(self):
        """Populate ``image_paths`` (and, in train mode, ``labels`` and the class maps)."""
        if self.is_test:
            self.image_paths = self._list_files(self.folder_path)
            return

        # Sort classes alphabetically: os.listdir order is arbitrary and can differ
        # between machines. Filter to directories *before* enumerating so stray
        # files do not leave gaps in the class indices.
        class_names = sorted(
            name for name in os.listdir(self.folder_path)
            if os.path.isdir(os.path.join(self.folder_path, name))
        )
        for idx, class_name in enumerate(class_names):
            self.class_to_idx[class_name] = idx
            self.idx_to_class[idx] = class_name
            class_files = self._list_files(os.path.join(self.folder_path, class_name))
            self.image_paths.extend(class_files)
            self.labels.extend([idx] * len(class_files))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Close the file handle immediately; convert() returns a new, fully
        # loaded image that no longer needs the underlying file.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        if self.is_test:
            return image, os.path.basename(image_path)
        return image, self.labels[idx], idx

# Root of the data folders on the mounted Drive (see the note at the top of the notebook).
DATA_ROOT = '/content/drive/MyDrive/data'

# Labelled training images, one sub-folder per class.
all_dataset = BirdDataset(os.path.join(DATA_ROOT, 'train'), transform=data_transform, is_test=False)

# 80/20 train/validation split; uses the global torch RNG (seeded earlier in the notebook).
train_size = int(0.8 * len(all_dataset))
val_size = len(all_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])

# Unlabelled test images. Use the deterministic test_transform: the random
# augmentations in data_transform (flip, colour jitter, rotation) must not be
# applied at inference time, or predictions become noisy and non-reproducible.
test_dataset = BirdDataset(os.path.join(DATA_ROOT, 'test'), transform=test_transform, is_test=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Training / validation loaders (not used by the inference code in this notebook).
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# ResNet-50 architecture with a 200-way classification head. No ImageNet
# weights are downloaded: load_state_dict (strict by default) overwrites every
# parameter and buffer with the fine-tuned checkpoint below. The `pretrained=`
# argument used previously is deprecated in recent torchvision releases.
NUM_CLASSES = 200
pretrained_model = models.resnet50()
pretrained_model.fc = torch.nn.Linear(pretrained_model.fc.in_features, NUM_CLASSES)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_model = pretrained_model.to(device)

# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
# torch.load unpickles the file, so only load checkpoints you trust.
trained_weights_path = '/content/drive/MyDrive/trained_model_fin.pth'
state_dict = torch.load(trained_weights_path, map_location=device)
pretrained_model.load_state_dict(state_dict)
pretrained_model.eval()

# Score every test image without tracking gradients. For each image, collect
# the index of its highest logit together with its file name.
predictions_test = []
image_names_test = []

with torch.no_grad():
    for batch_images, batch_names in test_loader:
        logits = pretrained_model(batch_images.to(device))
        batch_preds = logits.max(dim=1).indices
        predictions_test.extend(batch_preds.cpu().numpy())
        image_names_test.extend(batch_names)

# The checkpoint's output indices follow the class order used when the model
# was trained on Kaggle, which differs from the order the dataset builds here
# on Colab. Entry i of this list is the class folder name for output index i.
KAGGLE_CLASS_ORDER = [
    # 0-9
    '037.Acadian_Flycatcher',
    '145.Elegant_Tern',
    '115.Brewer_Sparrow',
    '125.Lincoln_Sparrow',
    '063.Ivory_Gull',
    '073.Blue_Jay',
    '116.Chipping_Sparrow',
    '135.Bank_Swallow',
    '187.American_Three_toed_Woodpecker',
    '174.Palm_Warbler',
    # 10-19
    '014.Indigo_Bunting',
    '061.Heermann_Gull',
    '092.Nighthawk',
    '129.Song_Sparrow',
    '022.Chuck_will_Widow',
    '036.Northern_Flicker',
    '169.Magnolia_Warbler',
    '127.Savannah_Sparrow',
    '102.Western_Wood_Pewee',
    '162.Canada_Warbler',
    # 20-29
    '120.Fox_Sparrow',
    '003.Sooty_Albatross',
    '076.Dark_eyed_Junco',
    '131.Vesper_Sparrow',
    '094.White_breasted_Nuthatch',
    '128.Seaside_Sparrow',
    '083.White_breasted_Kingfisher',
    '033.Yellow_billed_Cuckoo',
    '112.Great_Grey_Shrike',
    '072.Pomarine_Jaeger',
    # 30-39
    '182.Yellow_Warbler',
    '160.Black_throated_Blue_Warbler',
    '108.White_necked_Raven',
    '064.Ring_billed_Gull',
    '170.Mourning_Warbler',
    '042.Vermilion_Flycatcher',
    '171.Myrtle_Warbler',
    '081.Pied_Kingfisher',
    '098.Scott_Oriole',
    '168.Kentucky_Warbler',
    # 40-49
    '164.Cerulean_Warbler',
    '051.Horned_Grebe',
    '030.Fish_Crow',
    '031.Black_billed_Cuckoo',
    '035.Purple_Finch',
    '181.Worm_eating_Warbler',
    '195.Carolina_Wren',
    '008.Rhinoceros_Auklet',
    '089.Hooded_Merganser',
    '021.Eastern_Towhee',
    # 50-59
    '189.Red_bellied_Woodpecker',
    '186.Cedar_Waxwing',
    '158.Bay_breasted_Warbler',
    '121.Grasshopper_Sparrow',
    '130.Tree_Sparrow',
    '156.White_eyed_Vireo',
    '078.Gray_Kingbird',
    '047.American_Goldfinch',
    '079.Belted_Kingfisher',
    '071.Long_tailed_Jaeger',
    # 60-69
    '095.Baltimore_Oriole',
    '137.Cliff_Swallow',
    '139.Scarlet_Tanager',
    '196.House_Wren',
    '192.Downy_Woodpecker',
    '062.Herring_Gull',
    '067.Anna_Hummingbird',
    '143.Caspian_Tern',
    '024.Red_faced_Cormorant',
    '013.Bobolink',
    # 70-79
    '109.American_Redstart',
    '107.Common_Raven',
    '183.Northern_Waterthrush',
    '056.Pine_Grosbeak',
    '045.Northern_Fulmar',
    '007.Parakeet_Auklet',
    '017.Cardinal',
    '124.Le_Conte_Sparrow',
    '066.Western_Gull',
    '068.Ruby_throated_Hummingbird',
    # 80-89
    '193.Bewick_Wren',
    '088.Western_Meadowlark',
    '172.Nashville_Warbler',
    '157.Yellow_throated_Vireo',
    '159.Black_and_white_Warbler',
    '048.European_Goldfinch',
    '004.Groove_billed_Ani',
    '110.Geococcyx',
    '020.Yellow_breasted_Chat',
    '132.White_crowned_Sparrow',
    # 90-99
    '119.Field_Sparrow',
    '034.Gray_crowned_Rosy_Finch',
    '016.Painted_Bunting',
    '044.Frigatebird',
    '104.American_Pipit',
    '111.Loggerhead_Shrike',
    '002.Laysan_Albatross',
    '114.Black_throated_Sparrow',
    '080.Green_Kingfisher',
    '093.Clark_Nutcracker',
    # 100-109
    '075.Green_Jay',
    '144.Common_Tern',
    '050.Eared_Grebe',
    '166.Golden_winged_Warbler',
    '060.Glaucous_winged_Gull',
    '197.Marsh_Wren',
    '009.Brewer_Blackbird',
    '178.Swainson_Warbler',
    '163.Cape_May_Warbler',
    '173.Orange_crowned_Warbler',
    # 110-119
    '049.Boat_tailed_Grackle',
    '011.Rusty_Blackbird',
    '165.Chestnut_sided_Warbler',
    '155.Warbling_Vireo',
    '058.Pigeon_Guillemot',
    '006.Least_Auklet',
    '180.Wilson_Warbler',
    '043.Yellow_bellied_Flycatcher',
    '052.Pied_billed_Grebe',
    '140.Summer_Tanager',
    # 120-129
    '054.Blue_Grosbeak',
    '070.Green_Violetear',
    '141.Artic_Tern',
    '025.Pelagic_Cormorant',
    '194.Cactus_Wren',
    '057.Rose_breasted_Grosbeak',
    '133.White_throated_Sparrow',
    '100.Brown_Pelican',
    '085.Horned_Lark',
    '153.Philadelphia_Vireo',
    # 130-139
    '122.Harris_Sparrow',
    '123.Henslow_Sparrow',
    '027.Shiny_Cowbird',
    '190.Red_cockaded_Woodpecker',
    '154.Red_eyed_Vireo',
    '015.Lazuli_Bunting',
    '152.Blue_headed_Vireo',
    '101.White_Pelican',
    '147.Least_Tern',
    '118.House_Sparrow',
    # 140-149
    '012.Yellow_headed_Blackbird',
    '149.Brown_Thrasher',
    '200.Common_Yellowthroat',
    '055.Evening_Grosbeak',
    '059.California_Gull',
    '146.Forsters_Tern',
    '041.Scissor_tailed_Flycatcher',
    '161.Blue_winged_Warbler',
    '185.Bohemian_Waxwing',
    '126.Nelson_Sharp_tailed_Sparrow',
    # 150-159
    '087.Mallard',
    '191.Red_headed_Woodpecker',
    '029.American_Crow',
    '184.Louisiana_Waterthrush',
    '176.Prairie_Warbler',
    '096.Hooded_Oriole',
    '086.Pacific_Loon',
    '039.Least_Flycatcher',
    '117.Clay_colored_Sparrow',
    '113.Baird_Sparrow',
    # 160-169
    '038.Great_Crested_Flycatcher',
    '032.Mangrove_Cuckoo',
    '069.Rufous_Hummingbird',
    '091.Mockingbird',
    '198.Rock_Wren',
    '077.Tropical_Kingbird',
    '199.Winter_Wren',
    '053.Western_Grebe',
    '090.Red_breasted_Merganser',
    '097.Orchard_Oriole',
    # 170-179
    '074.Florida_Jay',
    '142.Black_Tern',
    '177.Prothonotary_Warbler',
    '105.Whip_poor_Will',
    '099.Ovenbird',
    '136.Barn_Swallow',
    '082.Ringed_Kingfisher',
    '046.Gadwall',
    '179.Tennessee_Warbler',
    '026.Bronzed_Cowbird',
    # 180-189
    '138.Tree_Swallow',
    '023.Brandt_Cormorant',
    '167.Hooded_Warbler',
    '018.Spotted_Catbird',
    '001.Black_footed_Albatross',
    '175.Pine_Warbler',
    '010.Red_winged_Blackbird',
    '103.Sayornis',
    '134.Cape_Glossy_Starling',
    '084.Red_legged_Kittiwake',
    # 190-199
    '028.Brown_Creeper',
    '040.Olive_sided_Flycatcher',
    '148.Green_tailed_Towhee',
    '065.Slaty_backed_Gull',
    '151.Black_capped_Vireo',
    '005.Crested_Auklet',
    '019.Gray_Catbird',
    '106.Horned_Puffin',
    '150.Sage_Thrasher',
    '188.Pileated_Woodpecker',
]

# Overwrite indices 0..199 of the dataset's mapping with the training-time order.
all_dataset.idx_to_class.update(enumerate(KAGGLE_CLASS_ORDER))



# Build the submission table: one row per test image, keyed by its file name
# without the extension and labelled with the full class-folder name.
test_predictions_df = pd.DataFrame({'id': image_names_test, 'label': predictions_test})
test_predictions_df['id'] = test_predictions_df['id'].apply(lambda name: os.path.splitext(name)[0])
# Predicted indices arrive as NumPy integers; int() makes the dict lookup explicit.
# An index missing from the mapping raises KeyError rather than writing a bad label.
test_predictions_df['label'] = test_predictions_df['label'].apply(
    lambda idx: all_dataset.idx_to_class[int(idx)]
)

# Save the predictions to CSV (relative to the current working directory).
test_predictions_df.to_csv('test_predictions_1.csv', index=False)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


