In [16]:
import os
import pandas as pd
import glob
from PIL import Image
import numpy as np
import cv2


import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

In [17]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
# cv2는 이미지 색상을 BGR로 불러와서, RGB 색상으로 변환해주는 함수
def BGR2RGB(image):
    rgb_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return rgb_img

### Normalization ###

# 컬러 이미지 Normalization 함수
def Normalization_Color(image):
    normalize_img = cv2.normalize(BGR2RGB(image), None, 0, 255, cv2.NORM_MINMAX) # RGB 이미지에 cv2.normalize() 적용
    shape = normalize_img.shape # 제대로 적용됐는지 이미지 shape 확인하기 위한 용도
    return normalize_img, shape

# 그레이 이미지 Normalization 함수
def Normalization_Gray(image):
    img = cv2.cvtColor(BGR2RGB(image), cv2.COLOR_RGB2GRAY) # RGB 색상을 Gray 색상으로 변환
    normalize_img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX) # Gray 이미지에 cv2.normalize() 적용
    shape = normalize_img.shape # 제대로 적용됐는지 이미지 shape 확인하기 위한 용도
    return normalize_img, shape

# 컬러 이미지 HE 함수
def HE_Color(image):
    '''
    컬러 이미지에 HE 적용할 땐 밝기 정보와 색상 정보를 분리한 후, 밝기 정보인 Y에 대해서만 HE 수행함
    색상 정보는 변환되지 않아서 색감이 그대로 유지되고, 밝기 정보인 Y에 대해서만 명암비가 증가하게 됨
    '''
    img = cv2.cvtColor(BGR2RGB(image), cv2.COLOR_RGB2YCrCb) # RGB 색상을 YCrCb 색상으로 변환
    img_planes = cv2.split(img) # Y(밝기 정보), Cr & Cb(색상 정보)로 split
    img_planes_0 = cv2.equalizeHist(img_planes[0]) # 밝기 정보인 Y에 대해서만 cv2.equalizeHist() 적용
    merge_img = cv2.merge([img_planes_0, img_planes[1], img_planes[2]]) # 변환된 Y와 색상 정보 merge
    he_img = cv2.cvtColor(merge_img, cv2.COLOR_YCrCb2RGB) # YCrCb 색상을 RGB 색상으로 변환
    shape = he_img.shape # 제대로 적용됐는지 이미지 shape 확인하기 위한 용도
    return he_img, shape

# 그레이 이미지 HE 함수
def HE_Gray(image):
    img = cv2.cvtColor(BGR2RGB(image), cv2.COLOR_RGB2GRAY) # RGB 색상을 Gray 색상으로 변환
    he_img = cv2.equalizeHist(img) # Gray 이미지에 cv2.equalizeHist() 적용
    shape = he_img.shape # 제대로 적용됐는지 이미지 shape 확인하기 위한 용도
    return he_img, shape

In [19]:
# grayscale None
# class CustomTransform:
#     def __call__(self, img):
#         img = np.array(img)
#         img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
#         #img, _ = HE_Gray(img)
#         #img, _ = Normalization_Gray(img)
#         return Image.fromarray(img)

# grayscale HE
# class CustomTransform:
#     def __call__(self, img):
#         img = np.array(img)
#         #img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
#         img, _ = HE_Gray(img)
#         #img, _ = Normalization_Gray(img)
#         return Image.fromarray(img)

# grayscale Norm
class CustomTransform:
    def __call__(self, img):
        img = np.array(img)
        #img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        #img, _ = HE_Gray(img)
        img, _ = Normalization_Gray(img)
        return Image.fromarray(img)

# grayscale Norm&HE
# class CustomTransform:
#     def __call__(self, img):
#         img = np.array(img)
#         #img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
#         img, _ = HE_Gray(img)
#         img, _ = Normalization_Gray(img)
#         return Image.fromarray(img)


In [20]:
N_WAY = 2
N_SHOT = 2
N_QUERY = 2
N_EVALUATION_TASK = 100

In [21]:
import os
import glob

class CustomDataset(Dataset):
    def __init__(self, base_dir, transform=None):
        normal_img_paths = glob.glob(os.path.join(base_dir, '0', '*.jpg'))
        abnormal_img_paths = glob.glob(os.path.join(base_dir, '1', '*.jpg'))

        self.img_paths = normal_img_paths + abnormal_img_paths
        self.transform = transform

        self.img_labels = [int(os.path.split(os.path.dirname(path))[-1]) for path in self.img_paths]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path)
        label = self.img_labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


In [22]:
train_set = CustomDataset(
    base_dir='/content/drive/MyDrive/data_all (1)/data_all/train_data',
    transform=transforms.Compose(
        [
            CustomTransform(),
            transforms.Resize((512, 512)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )
)

In [23]:
N_TRAINING_EPISODES = 10000
N_VALIDATION_TASK = 100

train_set.get_labels = lambda: train_set.img_labels
train_sampler = TaskSampler(
    train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)

train_loader = DataLoader(
    train_set,
    batch_sampler = train_sampler,
    num_workers = 1,
    pin_memory = True,
    collate_fn = train_sampler.episodic_collate_fn,
)

In [24]:
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone.forward(support_images)
        z_query = self.backbone.forward(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # And here is the super complicated operation to transform those distances into classification scores!
        scores = -dists
        return scores


convolutional_network = resnet18(pretrained=True)
# convolutional_network.fc = nn.Flatten()
num_ftrs = convolutional_network.fc.in_features
convolutional_network.fc = nn.Linear(num_ftrs, 2)

convolutional_network.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)


print(convolutional_network)

model = PrototypicalNetworks(convolutional_network).cuda()

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [25]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    classification_scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    )

    loss = criterion(classification_scores, query_labels.cuda())
    loss.backward()
    optimizer.step()

    return loss.item()

In [26]:
log_update_frequency = 10

all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,
    ) in tqdm_train:
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

100%|██████████| 10000/10000 [40:47<00:00,  4.09it/s, loss=0]


In [27]:
test_set = CustomDataset(
    base_dir='/content/drive/MyDrive/data_all (1)/data_all/test_data',
    transform=transforms.Compose(
        [
            CustomTransform(),
            transforms.Resize((512, 512)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )
)

N_EVALUATION_TASKS = 100

test_set.get_labels = lambda: test_set.img_labels
test_sampler = TaskSampler(
    test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=1,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [28]:
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> [int, int]:
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    return (
        torch.max(
            model(support_images.cuda(), support_labels.cuda(), query_images.cuda())
            .detach()
            .data,
            1,
        )[1]
        == query_labels.cuda()
    ).sum().item(), len(query_labels)


def evaluate(data_loader: DataLoader):
    total_predictions = 0
    correct_predictions = 0

    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):

            correct, total = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )

            total_predictions += total
            correct_predictions += correct

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )


evaluate(test_loader)

100%|██████████| 100/100 [00:31<00:00,  3.20it/s]

Model tested on 100 tasks. Accuracy: 99.25%





In [30]:
torch.save(model.state_dict(), '/content/drive/MyDrive/data_all (1)/Few Shot Learning_1_512_Normalization.pth')

In [31]:
# import os
# import glob

# import torch
# import torch.nn as nn
# from torchvision.models import resnet18
# from torchvision import transforms
# from torch.utils.data import Dataset, DataLoader
# from PIL import Image



# class PrototypicalNetworks(nn.Module):
#     def __init__(self, backbone: nn.Module):
#         super(PrototypicalNetworks, self).__init__()
#         self.backbone = backbone

#     def forward(
#         self,
#         support_images: torch.Tensor,
#         support_labels: torch.Tensor,
#         query_images: torch.Tensor,
#     ) -> torch.Tensor:
#         """
#         Predict query labels using labeled support images.
#         """
#         # Extract the features of support and query images
#         z_support = self.backbone.forward(support_images)
#         z_query = self.backbone.forward(query_images)

#         # Infer the number of different classes from the labels of the support set
#         n_way = len(torch.unique(support_labels))
#         # Prototype i is the mean of all instances of features corresponding to labels == i
#         z_proto = torch.cat(
#             [
#                 z_support[torch.nonzero(support_labels == label)].mean(0)
#                 for label in range(n_way)
#             ]
#         )

#         # Compute the euclidean distance from queries to prototypes
#         dists = torch.cdist(z_query, z_proto)

#         # And here is the super complicated operation to transform those distances into classification scores!
#         scores = -dists
#         return scores


# convolutional_network = resnet18(pretrained=True)
# # convolutional_network.fc = nn.Flatten()
# num_ftrs = convolutional_network.fc.in_features
# convolutional_network.fc = nn.Linear(num_ftrs, 2)

# print(convolutional_network)

# model = PrototypicalNetworks(convolutional_network).cuda()
# model.load_state_dict(torch.load('/BTS2023/byeonjun/BTS2023/few_shot/proto/result_pth/few_shot3.pth'))

In [None]:
# class SupportDataset(Dataset):
#     def __init__(self, base_dir, transform=None):
#         self.normal_img_paths = glob.glob(os.path.join(base_dir, '0', '*.jpg'))
#         self.abnormal_img_paths = glob.glob(os.path.join(base_dir, '1', '*.jpg'))

#         self.img_paths = self.normal_img_paths + self.abnormal_img_paths
#         self.img_labels = [0]*len(self.normal_img_paths) + [1]*len(self.abnormal_img_paths)
#         self.transform = transform

#     def __len__(self):
#         return len(self.img_paths)

#     def __getitem__(self, idx):
#         img_path = self.img_paths[idx]
#         image = Image.open(img_path)
#         label = self.img_labels[idx]

#         if self.transform:
#             image = self.transform(image)

#         return image, label

# class QueryDataset(Dataset):
#     def __init__(self, base_dir, transform=None):
#         self.img_paths = glob.glob(os.path.join(base_dir, '*.jpg'))
#         self.transform = transform

#     def __len__(self):
#         return len(self.img_paths)

#     def __getitem__(self, idx):
#         img_path = self.img_paths[idx]
#         image_name = os.path.basename(img_path)
#         image = Image.open(img_path)

#         if self.transform:
#             image = self.transform(image)

#         return image, image_name


# from collections import Counter

# def predict_with_support(support_loader, query_loader):
#     all_predictions = []

#     for support_images, support_labels in support_loader:
#         support_images = support_images.cuda()
#         support_labels = support_labels.cuda()

#         with torch.no_grad():
#             for query_images, image_names in query_loader:
#                 temp_preds = []

#                 # 각 이미지에 대한 예측을 5번 수행
#                 for _ in range(5):
#                     outputs = model(support_images, support_labels, query_images.cuda())
#                     _, preds = torch.max(outputs, 1)
#                     temp_preds.append(preds)

#                 # 3번의 예측 중 가장 많이 예측된 값을 최종 예측값으로 선택
#                 for i, img_name in enumerate(image_names):
#                     counter = Counter([temp[i].item() for temp in temp_preds])
#                     most_common_pred = counter.most_common(1)[0][0]
#                     all_predictions.append((img_name, most_common_pred))

#     return all_predictions



In [None]:
# transform = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomVerticalFlip(),
#     transforms.ToTensor(),
# ])

# support_set = SupportDataset(base_dir='/BTS2023/byeonjun/BTS2023/few_shot/proto/train_data', transform=transform)
# query_set = QueryDataset(base_dir='/BTS2023/byeonjun/BTS2023/few_shot/proto/test_all', transform=transform)

# support_loader = DataLoader(support_set, batch_size=8, shuffle=True)
# query_loader = DataLoader(query_set, batch_size=8, shuffle=False)

# predictions = predict_with_support(support_loader, query_loader)
# for file_path, pred in predictions:
#     print(f"File: {file_path}, Prediction: {pred}")