In [None]:
from params import params
import requests
from fine_tuning import fine_tuning_beit as fn
from image_collector import image_collector as ic
import torch
import os
from torch.optim import Adam
from torch.nn import CrossEntropyLoss 
from transformers import BeitConfig, BeitImageProcessor, BeitForImageClassification
# from srcs.image_collector.image_collector import load_image_collection, RGB_convert
# from make_graph import make_graph
from matplotlib import pyplot as plt

In [None]:
#------------------------ Download collections ---------------------------#
# os.makedirs(params.cats_val_path, exist_ok=True)
# cats_url = fn.fetch_cat_images()

# # Download and save images locally
# for idx, url in enumerate(cats_url):
#     img_data = requests.get(url).content
#     with open(f"cats_and_dogs_images/cat_{idx + 1}.jpg", 'wb') as file:
#         file.write(img_data)



In [None]:
#------------------------ Loading collection -----------------------------#
def _load_rgb(folder):
    """Return `ic.RGB_convert(ic.load_image_collection(folder, 10, ".jpg"))`.

    The literals 10 and ".jpg" are forwarded unchanged to
    `ic.load_image_collection`; presumably they are a cap on the number of
    images and the file extension to look for -- TODO confirm in image_collector.
    """
    loaded = ic.load_image_collection(folder, 10, ".jpg")
    return ic.RGB_convert(loaded)


# Training sets first, then the validation sets.
cat_collection = _load_rgb(params.cat_path)
dog_collection = _load_rgb(params.dog_path)
dog_val_collection = _load_rgb(params.dog_val_path)
cat_val_collection = _load_rgb(params.cat_val_path)

# Combined collections: the cat part always comes before the dog part.
val_collection = cat_val_collection + dog_val_collection
all_images = cat_collection + dog_collection

In [None]:
#-------------- Importing, initializing and setting model and processor ---#
# Hub checkpoint that provides the processor settings, the configuration and
# the pretrained weights.
BEIT_CHECKPOINT = 'microsoft/beit-base-patch16-224'

# Load processor
processor = BeitImageProcessor.from_pretrained(BEIT_CHECKPOINT)
# Load the pre-trained model configuration
config = BeitConfig.from_pretrained(BEIT_CHECKPOINT)
# Set number of classes to 2 (cats and dogs)
config.num_labels = 2
# Load the pretrained weights together with the updated config.
# Building the model with `BeitForImageClassification(config)` would only
# create the architecture with randomly initialised weights, so nothing
# pretrained would be fine-tuned. The checkpoint's classification head does not
# have 2 outputs, hence `ignore_mismatched_sizes=True`: the mismatching head is
# re-initialised while all matching weights are loaded.
model = BeitForImageClassification.from_pretrained(
    BEIT_CHECKPOINT,
    config=config,
    ignore_mismatched_sizes=True,
)
# `from_pretrained` returns the model in evaluation mode; switch back to
# training mode, which is the mode of a freshly constructed module.
model = model.train()



In [None]:
#-------------- Saving model architecture -------------------------------#
# Create both output directories if they do not exist yet.
for output_dir in (params.save_dir, params.save_archi):
    os.makedirs(output_dir, exist_ok=True)

# Only the configuration object is written here; no weights are saved in this cell.
config.save_pretrained(params.save_archi)
print("Model architecture saved!")

In [None]:
#-------------- Loading model on device (cpu) ----------------------------#
# The device is fixed to the CPU; the same `device` object is used later for
# the input and label tensors.
device = torch.device("cpu")
model.to(device=device)

In [None]:
#-------------- Defining collection and labels ---------------------------#
# Training images: the cat collection followed by the dog collection.
collection = cat_collection + dog_collection

# One integer class id per image, in the same order: 0 for cats, 1 for dogs.
cat_labels = [0] * len(cat_collection)
dog_labels = [1] * len(dog_collection)
labels = cat_labels + dog_labels

In [None]:
#-------------- Converting images and labels into tensors ----------------#
# The processor output is requested as PyTorch tensors ("pt"); the
# "pixel_values" entry is read from it further down.
batch = processor(collection, return_tensors="pt")
# Replace the Python list of class ids by a tensor holding the same values.
labels = torch.tensor(labels)

In [None]:
#-------------- loading tensors on device (cpu) --------------------------#
# Only the "pixel_values" entry of the processor output is used; it is moved
# with a single `.to(device)` call (no per-key loop over a dict is involved).
pixel_values = batch["pixel_values"]
inputs = pixel_values.to(device)
labels = labels.to(device)

In [None]:
#-------------- Training -----------------------------------------------#
# The returned object is indexed with "success_rate" and "losses" in the
# plotting cell.
trained_model = fn.train_and_eval_model(
    model=model,
    processor=processor,
    inputs=inputs,
    labels=labels,
    epochs=params.epochs,
    lr=params.learning_rate,
    val_collection=val_collection,
)

In [None]:
#-------------- Evaluation -----------------------------------------------#
# outputs = fn.eval_model(trained_model, processor, "cpu", val_collection)
# success_rate = fn.success_rate(len(val_collection), outputs)
# print(f"success rate is {success_rate} %")

In [None]:
#-------------- Plotting -------------------------------------------------#
# Plot the recorded success rates against their 1-based position and annotate
# every point with the loss stored at the same index.
success_rates = trained_model["success_rate"]
loss_values = trained_model["losses"]
x_positions = range(1, len(success_rates) + 1)

plt.figure(figsize=(10, 6))
plt.plot(x_positions, success_rates)
plt.scatter(x_positions, success_rates, s=10)
for i, value in enumerate(success_rates):
    # Label text is the loss rounded to 2 decimals, drawn 0.1 above the point.
    plt.text(i + 1, value + 0.1, round(loss_values[i], 2), ha='center', va='bottom', fontsize=10)

# savefig does not create missing directories, so make sure the target
# directory exists (the other output directories are created the same way).
os.makedirs(params.stat_path, exist_ok=True)
plt.savefig(params.stat_path + "/output_res")
plt.show()