In [2]:
import os
import numpy as np
import torch

# Addresses

class Address:
    '''
    Central registry of every filesystem path used in the project.

    Dataset locations are only recorded here. Output locations (``temp``,
    ``result`` and the per-model result folders) can be created with
    ``create_dir`` and emptied with ``clean``.
    '''

    def __init__(self):
        '''
        Stores all the addresses used in project
        '''
        # Inputs
        self.data = "../input/mammography"
        self.processed_data = "data.pkl"

        # Coco
        self.coco = os.path.join(self.data, 'coco_1k')
        self.coco_annot = os.path.join(self.coco, 'annotations')
        self.coco_annot_train = os.path.join(self.coco_annot, 'instances_train2017.json')
        self.coco_annot_val = os.path.join(self.coco_annot, 'instances_val2017.json')
        self.coco_img_train = os.path.join(self.coco, 'train2017')
        self.coco_img_val = os.path.join(self.coco, 'val2017')

        # Test
        self.test = os.path.join(self.data, 'test')
        self.test_img = os.path.join(self.test, 'images')
        self.test_label = os.path.join(self.test, 'labels')
        self.predictions = os.path.join(self.test, 'predictions')

        # Yolo
        self.yolo = os.path.join(self.data, 'yolo_1k')
        self.yolo_train = os.path.join(self.yolo, 'train')
        self.yolo_train_img = os.path.join(self.yolo_train, 'images')
        self.yolo_train_label = os.path.join(self.yolo_train, 'labels')
        self.yolo_val = os.path.join(self.yolo, 'val')
        self.yolo_val_img = os.path.join(self.yolo_val, 'images')
        self.yolo_val_labels = os.path.join(self.yolo_val, 'labels')

        # Models
        self.result = "results/"
        self.model_frcnn = os.path.join(self.result, 'frcnn')
        self.model_def_detr = os.path.join(self.result, 'deformable_dtr')
        self.model_detr = os.path.join(self.result, 'dtr')

        # Temp
        self.temp = "temp/"

    def create_dir(self, dir_list = None):
        '''
        Function to create directories in dir_list. If dir_list is None then create all directories of address.

        Missing parent directories are created as well, and paths that
        already exist are left untouched.
        '''
        if dir_list is None:
            dir_list = [self.temp, self.result, self.model_frcnn, self.model_def_detr, self.model_detr]
        for address in dir_list:
            if not os.path.exists(address):
                # makedirs (rather than mkdir) so a nested path does not fail
                # when its parent has not been created yet.
                os.makedirs(address, exist_ok=True)

    def _delete_folder_content(self, folder_addr):
        '''
        Deletes all the content of folder_addr

        The folder itself is kept (only its files and sub-folders are
        removed). Nothing happens when folder_addr does not exist.
        '''
        if os.path.exists(folder_addr):
            for file in os.listdir(folder_addr):
                address = os.path.join(folder_addr, file)
                # Symbolic links are unlinked, never followed, so the content
                # of a linked directory outside folder_addr is not touched.
                if os.path.isdir(address) and not os.path.islink(address):
                    self._delete_folder_content(address)
                    # rmdir removes exactly this (now empty) directory.
                    # os.removedirs would also prune every empty ancestor,
                    # i.e. folder_addr itself and possibly its parents.
                    os.rmdir(address)
                else:
                    os.remove(address)

    def clean(self, file_list = None):
        '''
        Deletes all the content in file_list

        file_list is a list of folder paths; when it is None only the temp
        folder is emptied. The listed folders themselves are preserved.
        '''
        if file_list is None:
            file_list = [self.temp]
        for address in file_list:
            self._delete_folder_content(address)

# Shared path registry used by the cells below.
addr = Address()
# clean() without arguments only empties the temp folder (addr.temp).
addr.clean()
# addr.clean([addr.model_detr])
# create_dir() without arguments creates temp/, results/ and the per-model result folders.
addr.create_dir()

class HyperParameters:
    '''
    Holds the training hyperparameters and can dump them to a text report.
    '''

    def __init__(self):
        '''
        Stores all Hyperparameters used for training of model
        '''
        # Training
        self.batch_size = 2
        self.num_epoch = 50
        self.grad_clip = 0.1

        # Data
        self.num_train = 2240
        self.num_val = 218
        # Full batches per epoch (floor division) times the number of epochs
        full_batches_per_epoch = self.num_train // self.batch_size
        self.train_step = self.num_epoch * full_batches_per_epoch
        self.resolution = (256, 512)

        # Learning Rate
        self.lr = 3e-5
        self.backbone_lr = 1e-5
        self.weight_decay = 3e-5

    def create_report(self, addr):
        '''
        Writes every hyperparameter to ``param.txt`` inside the folder ``addr``.

        The report has one titled section per group (training, data,
        learning rate). Sections are separated by a blank line and the file
        has no trailing newline.
        '''
        # (section title, ((padded label, value), ...)); labels carry their own padding
        sections = (
            ('Training:', (
                ('Batch Size:       ', self.batch_size),
                ('Num Epoch:        ', self.num_epoch),
                ('Grad Clip:        ', self.grad_clip),
            )),
            ('Data:', (
                ('Num Train:        ', self.num_train),
                ('Num Val:          ', self.num_val),
                ('Train Step:       ', self.train_step),
                ('Resolution:       ', self.resolution),
            )),
            ('Learning Rate:', (
                ('lr:               ', self.lr),
                ('Backbone lr:      ', self.backbone_lr),
                ('Weight Decay:     ', self.weight_decay),
            )),
        )
        report = '\n\n'.join(
            title + ''.join(f'\n\t{label}{value}' for label, value in rows)
            for title, rows in sections
        )
        with open(os.path.join(addr, 'param.txt'), 'w') as file:
            file.write(report)

# Module-level hyperparameter set; read by Data (batch size) and used as the default `param` of LearnModel.
param = HyperParameters()

# Random Seed and CUDA

random_seed = 68

# Seed the NumPy and PyTorch global generators with the same value.
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Prefer the GPU when one is visible; its generators are seeded as well.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.manual_seed_all(random_seed)
print(f"Working with device {device}")

Working with device cuda


In [3]:
import PIL.Image
import torch.utils.data
import torchvision
import transformers
import PIL.ImageEnhance, PIL.ImageFilter
import json

# !pip install pycocotools

# Checkpoint identifier passed to DetrImageProcessor.from_pretrained and DetrForObjectDetection.from_pretrained below.
model_checkpoint = "facebook/detr-resnet-50"

class Dataset(torchvision.datasets.CocoDetection):
    '''
    COCO-format detection dataset whose items are already encoded by the
    given image processor.

    Each item is the processor output for one preprocessed image with the
    leading batch dimension removed from ``pixel_values`` and ``labels``.
    '''

    def __init__(self, img_folder, annotation_file, processor):
        '''
        img_folder:       folder holding the images
        annotation_file:  COCO annotation json for those images
        processor:        callable image processor invoked as
                          ``processor(images=..., annotations=..., return_tensors="pt")``
        '''
        super().__init__(img_folder, annotation_file)

        self.processor = processor

    def __getitem__(self, idx):
        '''
        Returns the processor encoding of sample ``idx`` (image + annotations).
        '''
        img, target = super().__getitem__(idx)

        img = self.transform_fn(img)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")
        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse any other axis that happens to have size 1.
        encoding["pixel_values"] = encoding["pixel_values"].squeeze(0)
        encoding["labels"] = encoding["labels"][0] # remove batch dimension

        return encoding

    def transform_fn(self, img):
        '''
        Preprocessing: contrast enhancement (factor 3.0) followed by a
        Gaussian blur (radius 5). Returns a new PIL image.
        '''
        # enhance() returns the enhanced copy and leaves the wrapped image
        # unchanged, so its result has to be used for the next step.
        # NOTE(review): before this fix the enhanced copy was discarded and
        # only the blur was applied; checkpoints trained with the old
        # preprocessing saw blurred but not contrast-enhanced images.
        enhanced = PIL.ImageEnhance.Contrast(img).enhance(3.0)         # Increasing Contrast
        return enhanced.filter(PIL.ImageFilter.GaussianBlur(radius=5))       # Gaussian Blur

class EvalDataSet():
    '''
    Lookup of COCO image records by image id, with lazily opened images.
    '''

    def __init__(self, address_img, address_annot):
        '''
        Creates Dataset of images on given address.

        address_img:    folder that contains the image files
        address_annot:  COCO annotation json describing those images
        '''
        self.address_img = address_img
        self.temp_annotation = None
        with open(address_annot, 'rb') as file:
            self.annotation = json.load(file)
        # Re-index the raw COCO structure by image id
        self.beautify()

    def __len__(self):
        '''
        Number of image records.
        '''
        return len(self.annotation)

    def __getitem__(self, idx):
        '''
        Returns the record stored under key ``idx`` (an image id after
        ``beautify``). The image is opened on first access and cached in the
        record under ``'image'``.
        '''
        record = self.annotation[idx]
        if 'image' not in record:
            record['image'] = PIL.Image.open(os.path.join(self.address_img, record['file_name']))
        return record

    def beautify(self):
        '''
        Replaces ``self.annotation`` by a dict ``image id -> image record``,
        where every record gains an ``'objects'`` list holding the
        annotations that reference it.
        '''
        image_records = self.annotation['images']
        object_records = self.annotation['annotations']

        by_id = {}
        for record in image_records:
            record['objects'] = []
            by_id[record['id']] = record

        for obj in object_records:
            by_id[obj['image_id']]['objects'].append(obj)

        self.annotation = by_id

class Data():
    '''
    Bundles the image processor, the train/val datasets, their raw-image
    lookups, the data loaders and the label maps.
    '''

    def __init__(self):
        '''
        Builds everything from the module-level ``addr``, ``param`` and
        ``model_checkpoint``.
        '''
        processor = transformers.DetrImageProcessor.from_pretrained(model_checkpoint)
        self.processor = processor

        # Encoded datasets (what the model consumes)
        self.dataset_train = Dataset(addr.coco_img_train, addr.coco_annot_train, processor)
        self.dataset_val = Dataset(addr.coco_img_val, addr.coco_annot_val, processor)

        # Image records indexed by image id (used when plotting predictions)
        self.img_train = EvalDataSet(addr.coco_img_train, addr.coco_annot_train)
        self.img_val = EvalDataSet(addr.coco_img_val, addr.coco_annot_val)

        # Only the training loader is shuffled
        self.loader_train = torch.utils.data.DataLoader(
            self.dataset_train,
            batch_size=param.batch_size,
            shuffle=True,
            collate_fn=self.collate_fn,
        )
        self.loader_val = torch.utils.data.DataLoader(
            self.dataset_val,
            batch_size=param.batch_size,
            shuffle=False,
            collate_fn=self.collate_fn,
        )

        self.id2label = {0: 'mal'}
        self.label2id = {name: index for index, name in self.id2label.items()}

    def collate_fn(self, batch):
        '''
        Collates dataset items: the images go through ``processor.pad`` and
        the per-image label entries are attached as a plain list under
        ``'labels'``.
        '''
        images, targets = [], []
        for sample in batch:
            images.append(sample['pixel_values'])
            targets.append(sample['labels'])

        encoding = self.processor.pad(images, return_tensors="pt")
        encoding['labels'] = targets
        return encoding

# Builds the processor, datasets and loaders once; used as the default `data` of LearnModel.
data = Data()

loading annotations into memory...
Done (t=0.11s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [5]:
import time
import matplotlib.pyplot as plt

# Learning Model
class LearnModel:
    def __init__(self, model: torch.nn.Module, model_addr, data=data, param=param, device=device):
        '''
        Train, Evaluate and Predict

        model:       network to train / evaluate; it is moved to ``device``
        model_addr:  folder that receives the loss history and the
                     per-epoch checkpoints (sub-folder ``model``)
        data, param, device: default to the module-level objects of the
                     same name
        '''
        self.data, self.param, self.device = data, param, device
        self.model = model.to(device)
        self.model_addr = model_addr

        # Addresses
        self.loss_addr = os.path.join(model_addr, 'loss.npz')

        def epoch_addr(epoch):
            # Checkpoint path of one epoch; reads self.model_addr at call time
            return os.path.join(self.model_addr, f'model/{epoch}.pth')

        self.epoch_addr = epoch_addr
        addr.create_dir([os.path.join(model_addr, 'model')])

    def train(self, epoch_log = True, batch_log = True, overwrite = False):
        '''
        Trains the model for ``self.param.num_epoch`` epochs.

        After every epoch the train/val loss history is written to
        ``self.loss_addr`` and the weights to ``self.epoch_addr(epoch)``.
        An epoch whose checkpoint file already exists is not trained again:
        its weights are loaded and the loop moves on, which lets an
        interrupted run resume. The optimizer state is not checkpointed, so
        a resumed run starts with a fresh optimizer.

        epoch_log:  print the train/val loss after every epoch
        batch_log:  print the running train loss every 25 batches
        overwrite:  wipe the content of ``self.model_addr`` (loss history and
                    checkpoints) and start from scratch
        '''
        hp = self.param      # the instance's hyperparameters, not the module-level global
        param_dicts = [
              {"params": [p for n, p in self.model.named_parameters() if "backbone" not in n and p.requires_grad]},
              {
                  "params": [p for n, p in self.model.named_parameters() if "backbone" in n and p.requires_grad],
                  "lr": hp.backbone_lr,
              },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=hp.lr, weight_decay=hp.weight_decay)
        start_time = time.time()

        # Wipe BEFORE reading the loss history, otherwise the old losses
        # would be kept in memory and written back after the first epoch.
        if overwrite:
            addr.clean([self.model_addr])
        # The wipe also removes the checkpoint sub-folder; make sure it exists.
        os.makedirs(os.path.join(self.model_addr, 'model'), exist_ok=True)

        # Loss arr
        if os.path.exists(self.loss_addr):
            # Context manager closes the npz file, which is rewritten below.
            with np.load(self.loss_addr) as loss_arr:
                train_loss_arr = list(loss_arr['train'])
                val_loss_arr = list(loss_arr['val'])
        else:
            train_loss_arr = []
            val_loss_arr = []

        for epoch in range(hp.num_epoch):
            epoch_addr = self.epoch_addr(epoch)

            # Loading Model if present
            if os.path.exists(epoch_addr):
                # map_location lets a checkpoint saved on GPU load on a CPU-only machine
                state = torch.load(epoch_addr, map_location=self.device)
                self.model.load_state_dict(state, strict=False)
                print(f"Loaded model checkpoint of epoch {epoch}")
                continue

            # Training Model
            train_loss = self.train_epoch(optimizer, batch_log=batch_log)
            if epoch_log:
                print(f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTime: {time.time()-start_time}')

            # Validating Model
            val_loss = self.validate_epoch(self.data.loader_val, batch_log=False)
            if epoch_log:
                print(f'Epoch: {epoch}\tVal Loss: {val_loss}\tTime: {time.time()-start_time}')

            # Saving data
            train_loss_arr.append(train_loss)
            val_loss_arr.append(val_loss)
            np.savez_compressed(self.loss_addr, train=np.array(train_loss_arr), val=np.array(val_loss_arr))     # Saving Loss Array
            torch.save(self.model.state_dict(), epoch_addr)     # Saving Model

            # Printing blank line between each epoch in Log
            if epoch_log:
                print()
        
    def train_epoch(self, optimizer, batch_log):
        '''
        Trains model for one epoch

        optimizer:  optimizer stepped once per batch
        batch_log:  when true, prints the running mean loss every 25 batches

        Returns the mean loss over all batches of the training loader.
        '''
        running_loss = 0
        batch_ct = 0
        loader = self.data.loader_train
        began = time.time()

        self.model.train()          # Set Model to train Mode

        for batch_ct, batch in enumerate(loader, start=1):
            # Move the batch to the working device
            pixel_values = batch["pixel_values"].to(self.device)
            pixel_mask = batch["pixel_mask"].to(self.device)
            labels = [
                {key: tensor.to(self.device) for key, tensor in target.items()}
                for target in batch["labels"]
            ]

            # Forward pass; the model computes the loss from the labels
            outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
            loss = outputs.loss
            running_loss += loss.item()

            # Backward pass with gradient-norm clipping
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.param.grad_clip)
            optimizer.step()

            if batch_log and batch_ct % 25 == 0:
                print(f"\tBatch {batch_ct}\tLoss: {running_loss/batch_ct}\tTime: {time.time()-began}")

        return running_loss / batch_ct
    
    def validate_epoch(self, dataloader, batch_log):
        '''
        Calculates Loss on data in given dataloader

        dataloader:  loader yielding batches with ``pixel_values``,
                     ``pixel_mask`` and ``labels``
        batch_log:   when true, prints the running mean loss every 25 batches

        Returns the mean loss over all batches; no gradients are computed.
        '''
        running_loss = 0
        batch_ct = 0
        began = time.time()

        self.model.eval()           # Set Model to eval mode

        with torch.no_grad():
            for batch_ct, batch in enumerate(dataloader, start=1):
                # Move the batch to the working device
                pixel_values = batch["pixel_values"].to(self.device)
                pixel_mask = batch["pixel_mask"].to(self.device)
                labels = [
                    {key: tensor.to(self.device) for key, tensor in target.items()}
                    for target in batch["labels"]
                ]

                # Forward pass; the model computes the loss from the labels
                outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
                running_loss += outputs.loss.item()

                if batch_log and batch_ct % 25 == 0:
                    print(f"\tBatch {batch_ct}\tLoss: {running_loss/batch_ct}\tTime: {time.time()-began}")

        return running_loss / batch_ct

    def infer(self):
        '''
        Runs the model on the validation loader and stores, for every image:

        * ``<model_addr>/nms/<image name>/<iou>.png`` - one plot per IoU
          threshold showing the predicted boxes kept by non-maximum
          suppression (red) over the annotated boxes (blue);
        * ``<model_addr>/pred/<image name>.txt`` - the boxes kept at IoU
          threshold 0, one per line as ``0 x_center y_center width height score``
          with coordinates divided by the image width / height.

        Every query of the model is kept before NMS (score threshold 0.0).
        '''
        iou_threshold = [0, 0.1, 0.25, 0.5, 1]

        def plot_results(pil_img, boxes, orig_boxes, file_addr):
            # boxes: (xmin, ymin, xmax, ymax) predictions, drawn in red
            # orig_boxes: (xmin, ymin, width, height) annotations, drawn in blue
            plt.figure(figsize=(5, 10))
            plt.imshow(pil_img)
            ax = plt.gca()
            for (xmin, ymin, xmax, ymax) in boxes.tolist():
                ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                        fill=False, color='r', linewidth=1))
            for (xmin, ymin, xwidth, ywidth) in orig_boxes:
                ax.add_patch(plt.Rectangle((xmin, ymin), xwidth, ywidth,
                                        fill=False, color='b', linewidth=1))
            plt.axis('off')
            plt.savefig(file_addr)
            plt.close()

        self.model.eval()
        nms_addr = os.path.join(self.model_addr, 'nms')
        pred_addr = os.path.join(self.model_addr, 'pred')
        addr.create_dir([nms_addr, pred_addr])

        with torch.no_grad():
            for batch in self.data.loader_val:
                # Copying data to cuda
                pixel_values = batch["pixel_values"].to(self.device)
                pixel_mask = batch["pixel_mask"].to(self.device)
                labels = batch["labels"]

                # Forward Propagation (no labels: the loss is not needed here)
                outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)
                target_sizes = torch.stack([target['orig_size'] for target in labels], dim = 0).to(self.device)

                # Boxes rescaled to the original image size, every query kept
                results = self.data.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)

                # Plotting with nms
                for target, single_result in zip(labels, results):
                    img_id = target['image_id'].item()
                    img_data = self.data.img_val[img_id]
                    # splitext instead of [:-4] so extensions of any length work
                    stem = os.path.splitext(img_data['file_name'])[0]
                    height, width = img_data['height'], img_data['width']
                    img = img_data['image']
                    orig_boxes = [elem['bbox'] for elem in img_data['objects']]

                    boxes = single_result['boxes'].cpu()
                    scores = single_result['scores'].cpu()

                    address = os.path.join(nms_addr, stem)
                    addr.create_dir([address])

                    for threshold in iou_threshold:
                        keep = torchvision.ops.nms(boxes, scores, threshold)
                        kept_boxes = boxes[keep]
                        plot_results(img, kept_boxes, orig_boxes, os.path.join(address, f"{threshold}.png"))
                        if threshold == 0:
                            # Boxes AND scores are indexed with the same NMS
                            # result so that each box is written with its own score.
                            kept_scores = scores[keep]
                            with open(os.path.join(pred_addr, f'{stem}.txt'), 'w') as file:
                                for box, score in zip(kept_boxes.tolist(), kept_scores.tolist()):
                                    x = (box[0] + box[2])/(2*width)
                                    y = (box[1] + box[3])/(2*height)
                                    w = (box[2] - box[0])/(width)
                                    h = (box[3] - box[1])/(height)
                                    file.write(f'0 {x} {y} {w} {h} {score}\n')

    def generate_heatmap(self):
        '''
        Work-in-progress attention visualisation.

        As written, only the first validation batch is processed: forward
        hooks capture the output of the last backbone stage and of the last
        encoder / decoder self-attention modules, one forward pass is run,
        the shape of the captured decoder value is printed and the loop is
        left. The plotting code after the first ``break`` never executes.
        '''
        self.model.eval()

        with torch.no_grad():
            # NOTE(review): reads the module-level `data`, not self.data.
            for batch in data.loader_val:
                # Copying data to cuda
                pixel_values = batch["pixel_values"].to(self.device)
                pixel_mask = batch["pixel_mask"].to(self.device)
                # Shallow copy of the label dicts; unlike in the other methods the tensors are not moved to self.device
                labels = [{k: v for k, v in t.items()} for t in batch["labels"]]
                target_sizes = torch.stack([target['orig_size'] for target in labels], dim = 0)

                # Filled by the forward hooks registered below
                conv_features, enc_attn_weights, dec_attn_weights = [], [], []

                # NOTE(review): the first lambda parameter is named `self`; inside each lambda it
                # is the hooked module, not this LearnModel instance.
                # NOTE(review): output[1] is presumably the attention-weight tensor of the module;
                # it may be None unless attention outputs are requested - TODO confirm.
                # NOTE(review): the hooks are not removed if the forward pass below raises.
                hooks = [
                    self.model.model.backbone.conv_encoder.model.encoder.stages[-1].register_forward_hook(
                        lambda self, input, output: conv_features.append(output)
                    ),
                    self.model.model.encoder.layers[-1].self_attn.register_forward_hook(
                        lambda self, input, output: enc_attn_weights.append(output[1])
                    ),
                    self.model.model.decoder.layers[-1].self_attn.register_forward_hook(
                        lambda self, input, output: dec_attn_weights.append(output[1])
                    ),
                ]

                # Forward Propagation
                outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)
                results = data.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0)
                # For every image: a one-element list holding the box with the highest score
                bboxes_scaled = [[result['boxes'][torch.argmax(result['scores'])]] for result in results]

                for hook in hooks:
                    hook.remove()
                
                # don't need the list anymore
                conv_features = conv_features[0]
                enc_attn_weights = enc_attn_weights[0]
                dec_attn_weights = dec_attn_weights[0]

                # print(enc_attn_weights.shape)
                print(dec_attn_weights.shape)
                # NOTE(review): unconditional break - everything below it inside this loop is unreachable.
                break
                
                # NOTE(review): dead code. It references `idx`, `im`, `CLASSES` and `probas`, none of
                # which is defined in this method; it would need to be completed before the
                # break above can be removed.
                for ind in range(conv_features.shape[0]):
                    h, w = conv_features[ind].shape[-2:]

                    fig, axs = plt.subplots(ncols=len(bboxes_scaled[ind]), nrows=2, figsize=(22, 7))
                    for ax_i, (xmin, ymin, xmax, ymax) in zip(axs.T, bboxes_scaled[ind]):
                        ax = ax_i[0]
                        ax.imshow(dec_attn_weights[0, idx].view(h, w))
                        ax.axis('off')
                        ax.set_title(f'query id: {idx.item()}')
                        ax = ax_i[1]
                        ax.imshow(im)
                        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                                fill=False, color='blue', linewidth=3))
                        ax.axis('off')
                        ax.set_title(CLASSES[probas[idx].argmax()])
                    fig.tight_layout()

                break

    def plot_loss(self, addr = None):
        '''
        Plots Loss vs number of epochs

        addr:  path of the image file to write; defaults to ``loss_curve``
               inside ``self.model_addr``

        The curve is drawn on a figure of its own, so repeated calls (or
        other open figures) never end up overlaid in the saved image. The
        figure is left open so that a notebook can still display it.

        Raises FileNotFoundError when no loss history has been saved yet.
        '''
        if not os.path.exists(self.loss_addr):
            raise FileNotFoundError("No Loss Array")
        # Context manager closes the npz file once both arrays are read
        with np.load(self.loss_addr) as loss_arr:
            train_arr, val_arr = loss_arr['train'], loss_arr['val']
        num_epoch = train_arr.shape[0]
        x_arr = np.arange(1, num_epoch + 1)     # epochs are numbered from 1 on the x axis

        if addr is None:
            addr = os.path.join(self.model_addr, 'loss_curve')

        fig, ax = plt.subplots()
        ax.set_title("Loss Curve")
        ax.set_xlabel("Number of Epochs")
        ax.set_ylabel("Loss")
        ax.plot(x_arr, train_arr, label='Train')
        ax.plot(x_arr, val_arr, label='Val')
        ax.legend()
        fig.savefig(addr)

    def best_model(self):
        '''
        Returns Best Model as well as changes self.model in place to best model

        "Best" is the epoch with the lowest validation loss in the saved loss
        history; that epoch's checkpoint is loaded into ``self.model``.

        Raises FileNotFoundError when no loss history has been saved yet.
        '''
        if not os.path.exists(self.loss_addr):
            raise FileNotFoundError("No Loss Array")
        # Context manager closes the npz file once the index is computed
        with np.load(self.loss_addr) as loss_arr:
            best_epoch = int(np.argmin(loss_arr['val']))
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        state = torch.load(self.epoch_addr(best_epoch), map_location=self.device)
        self.model.load_state_dict(state, strict=False)
        return self.model

    def create_report(self):
        '''
        Writes the hyperparameter report of ``self.param`` into the model folder
        (delegates to ``self.param.create_report``).
        '''
        self.param.create_report(self.model_addr)

# Pretrained DETR with its classification head resized for this dataset's labels.
model = transformers.DetrForObjectDetection.from_pretrained(
    model_checkpoint,
    revision = "no_timm",
    num_labels = len(data.id2label),
    # The checkpoint's classifier has a different shape than the one requested here,
    # so the mismatching weights are newly initialised (see the warning printed below).
    ignore_mismatched_sizes=True,
).to(device)

# Results of this run are stored under addr.model_detr.
learner = LearnModel(model, addr.model_detr)
# NOTE(review): train() and best_model() are commented out, so infer() runs with whatever
# weights `model` holds at this point - confirm that this is intended.
# learner.create_report() 
# learner.train()
# learner.plot_loss()
# learner.best_model()
learner.infer()
# learner.generate_heatmap()

Some weights of DetrForObjectDetection were not initialized from the model checkpoint at facebook/detr-resnet-50 and are newly initialized because the shapes did not match:
- class_labels_classifier.weight: found shape torch.Size([92, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- class_labels_classifier.bias: found shape torch.Size([92]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


43


DetrForObjectDetection(
  (model): DetrModel(
    (backbone): DetrConvModel(
      (conv_encoder): DetrConvEncoder(
        (model): ResNetBackbone(
          (embedder): ResNetEmbeddings(
            (embedder): ResNetConvLayer(
              (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
              (normalization): DetrFrozenBatchNorm2d()
              (activation): ReLU()
            )
            (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          )
          (encoder): ResNetEncoder(
            (stages): ModuleList(
              (0): ResNetStage(
                (layers): Sequential(
                  (0): ResNetBottleNeckLayer(
                    (shortcut): ResNetShortCut(
                      (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                      (normalization): DetrFrozenBatchNorm2d()
                    )
                    (layer): Seq