In [None]:
# Show the current working directory (bare `pwd` relies on IPython automagic;
# it is not valid in a plain Python script).
pwd

In [None]:
import os

In [None]:
# Move one directory up — presumably to the project root, so the `src.` package
# imports in the next cell resolve (TODO confirm where this notebook lives).
os.chdir("../")

In [None]:
# Confirm the working directory after the change.
pwd

In [None]:
import os
import torch

from src.cancer_detection.components.model import vgg16_modified, cancerClassifier
from src.cancer_detection.components.data_module import ImageTransform
from src.cancer_detection.config.configuration import ConfigurationManager
from src.cancer_detection.entity.config_dataclasses import TrainingConfig, InfernceConfig 
from src.cancer_detection.constants import DEVICE
from src.cancer_detection import logger
from PIL import Image
from typing import Type, List, Dict
from pathlib import Path
import gdown
import zipfile


class PredictionPipeline:
    """Single-image inference with the trained cancer classifier.

    The pipeline holds no state; everything it needs is passed to
    :meth:`_predict`.
    """

    def __init__(self):
        pass

    @staticmethod
    def _download_checkpoint(checkpoint_url: str, out_path: Path) -> None:
        """Download a Google Drive share link to ``out_path``.

        Assumes ``checkpoint_url`` looks like
        ``https://drive.google.com/file/d/<file_id>/view...`` so the file id is
        the second-to-last ``/``-separated segment — TODO confirm for other
        link formats.

        Raises:
            RuntimeError: if the download finished without creating ``out_path``.
        """
        logger.info(f"Downloading data from {checkpoint_url} into file {out_path}")
        file_id = checkpoint_url.split("/")[-2]
        # Direct-download endpoint (the previous prefix had a stray "/" after "?").
        prefix = "https://drive.google.com/uc?export=download&id="
        # gdown does not create missing parent directories for the output file.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        gdown.download(prefix + file_id, str(out_path))
        if not out_path.is_file():
            raise RuntimeError(
                f"Checkpoint download from {checkpoint_url} did not create {out_path}"
            )

    @staticmethod
    def _load_model(
        training_config: TrainingConfig,
        inference_config: InfernceConfig,
    ):
        """Build the backbone and restore the best checkpoint into a classifier.

        When ``inference_config.load_from_local`` is truthy, the checkpoint is
        read from ``path_to_best_checkpoint_local``. Otherwise it is downloaded
        from ``URL_to_load_from_drive`` (only if not already present) into
        ``best_model_checkpoints_saved_from_URL`` and loaded from there.

        Raises:
            FileNotFoundError: the local checkpoint does not exist.
            RuntimeError: the Drive download produced no file.
            Exception: any other loading error is logged and re-raised.
        """
        backbone = vgg16_modified(training_config)

        if inference_config.load_from_local:
            checkpoint_path = inference_config.path_to_best_checkpoint_local
            try:
                return cancerClassifier.load_from_checkpoint(
                    checkpoint_path,
                    map_location=DEVICE,
                    model=backbone,
                    config=training_config,
                )
            except FileNotFoundError:
                # Re-raise: swallowing this left the model unbound and caused a
                # confusing UnboundLocalError further down.
                logger.error(f"Error: File '{checkpoint_path}' not found.")
                raise
            except Exception as e:
                logger.exception(f"Error: An unexpected error occurred - {e}")
                raise

        out_path = inference_config.best_model_checkpoints_saved_from_URL
        if not out_path.is_file():
            PredictionPipeline._download_checkpoint(
                str(inference_config.URL_to_load_from_drive), out_path
            )
        return cancerClassifier.load_from_checkpoint(
            str(out_path),
            map_location=DEVICE,
            model=backbone,
            config=training_config,
        )

    @staticmethod
    def _predict(
        training_config: TrainingConfig,
        inference_config: InfernceConfig,
        image_transformation_pipeline: ImageTransform,
        filename: Path,
    ) -> List[Dict[str, str]]:
        """Classify one image file.

        Args:
            training_config: config used to build the backbone and classifier.
            inference_config: where to load the best checkpoint from.
            image_transformation_pipeline: callable invoked as
                ``pipeline(img, "test")`` that returns an image tensor.
            filename: path to the image to classify.

        Returns:
            A one-element list ``[{"image": label}]``. ``label`` is ``"Normal"``
            when the arg-max class index is 1, otherwise
            ``"Adenocarcinoma Cancer"``.
        """
        model_ = PredictionPipeline._load_model(training_config, inference_config)
        # The input goes to DEVICE below, so the weights must live there too.
        model_.to(DEVICE)
        model_.eval()

        # `with` releases the underlying file handle once the tensor is built.
        with Image.open(filename) as img:
            test_image = (
                image_transformation_pipeline(img, "test").unsqueeze(dim=0).to(DEVICE)
            )

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model_(test_image)
        preds = torch.argmax(logits, dim=1)

        prediction = "Normal" if preds[0].item() == 1 else "Adenocarcinoma Cancer"
        return [{"image": prediction}]

In [None]:
# Path of the image to classify — set this before running the cell.
# (The previous `path_to_image =` with no value was a SyntaxError.)
path_to_image = Path("path/to/your/image.png")
if not path_to_image.is_file():
    raise FileNotFoundError(
        f"Set `path_to_image` to an existing image file (got '{path_to_image}')."
    )

# Build both configs from one manager instead of constructing it twice.
config_manager = ConfigurationManager()
training_config = config_manager.get_training_config()
inference_config = config_manager.get_inference_config()
# Only the first entry of params_image_size is passed — presumably a square
# side length; TODO confirm against ImageTransform's signature.
image_transformation_pipeline = ImageTransform(training_config.params_image_size[0])

preds = PredictionPipeline()
# Left as the cell's last expression so Jupyter displays the prediction.
preds._predict(training_config, inference_config, image_transformation_pipeline, path_to_image)

In [None]:
from PIL import Image

In [None]:
# Solid-green 224x224 RGB image, usable as a dummy input for the pipeline.
img1 = Image.new(mode="RGB", size=(224, 224), color="green")

In [None]:
# Write the dummy image to the current working directory as a PNG file.
img1.save(fp="image3.png", format="PNG")