In [None]:
import torch
from torchvision import models
from torch import nn
from torch.utils.data.dataset import Dataset
from xml.etree import ElementTree
from zipfile import ZipFile
from io import BytesIO
import numpy as np
import cv2
from tqdm import tqdm
import pickle
from pathlib import Path
from comet_ml import Experiment as Comet
import torch
from tqdm import tqdm
import gc
import datetime
from pathlib import Path
import numpy as np
from pathlib import Path
from albumentations import HorizontalFlip, ShiftScaleRotate, RandomBrightnessContrast, Compose
import cv2
import numpy as np
from pathlib import Path


class Data(Dataset):
    """
    A PyTorch Dataset class to load and process the Plant Tracer data

    Parameters
    ----------
    path: Path
        Path to location of frames and annotations
        The videos are expected as ZIP file containing all the frames
        The annotations are expected in PASCAL_VOC format
    target_size: int
        The size of image to use when training the model
        The image will be of size (@target_size x @target_size)
    transforms: bool
        Apply augmentation transformations on the image
    """

    def __init__(self, path: Path, target_size: int, transforms: bool = True):
        self.target_size = target_size
        self.transforms = transforms
        all_frames = list(path.glob('*.zip'))
        all_annotations = list(path.glob('*.xml'))
        self.video_annotations = [self.load_annotations(a) for a in
                                  tqdm(all_annotations, desc='Loading Annotations')]
        self.video_frames = [self.load_frames(f) for f in tqdm(all_frames, desc='Loading Video Frames')]
        self.__sanity_check__()
        if self.transforms:
            self.prev_hflip = Compose([HorizontalFlip(p=1)])
            self.prev_ssr = Compose([ShiftScaleRotate(p=1)])
            self.prev_bri = Compose([RandomBrightnessContrast(p=1)])
            self.curr_hflip = Compose([HorizontalFlip(p=1)],
                                      bbox_params={'format': 'pascal_voc', 'label_fields': ['category_id']})
            self.curr_ssr = Compose([ShiftScaleRotate(p=1)],
                                    bbox_params={'format': 'pascal_voc', 'label_fields': ['category_id']})
            self.curr_bri = Compose([RandomBrightnessContrast(p=1)],
                                    bbox_params={'format': 'pascal_voc', 'label_fields': ['category_id']})

    def __len__(self):
        """
        Calculate the total length of dataset
        Returns
        -------
            int: Length of dataset
        """
        return sum([len(videos) - 1 for videos in self.video_frames])

    def __getitem__(self, index):
        """
        Retrieve a random item from the dataset

        Returns
        -------
            ndarray: The previous frame cropped based on annotation
            ndarray: The current frame cropped based on previous annotation
            ndarray: The current annotation rescaled
            list: The scale used to crop resize the frames
            list: The amount of crop performed on the frames
        """
        # Select a random video and then, select a random frame in that video
        np.random.seed(seed=index)
        vi = np.random.randint(low=0, high=len(self.video_frames) - 1, size=1)[0]
        fi = np.random.randint(low=1, high=len(self.video_frames[vi]), size=1)[0]
        np.random.seed(seed=None)
        previous_frame = self.video_frames[vi][fi - 1]
        current_frame = self.video_frames[vi][fi]
        previous_annotation = self.video_annotations[vi][fi - 1]
        current_annotation = self.video_annotations[vi][fi]
        return self.make_crops(previous_frame, current_frame, previous_annotation, current_annotation)

    def __sanity_check__(self):
        """
        Make a check to see if number of videos and annotations match
        Raises
        -------
        ValueError
            If the number of video frames does not match the number of annotations, or
            if each video does not have its respective annotations
        """
        if len(self.video_annotations) != len(self.video_frames):
            raise ValueError('Sizes of annotations and videos do not match')
        # Also make a check to see if number of frames in each video and number of annotations for each video match
        else:
            for i, d in enumerate(zip(self.video_annotations, self.video_frames)):
                if d[0].shape[0] != d[1].shape[0]:
                    raise ValueError('Sizes of annotations and videos do not match in {}'.format(i + 1))

    @staticmethod
    def load_pickles(path: Path):
        """
        Load the video frames and annotations using pickle files

        Parameters
        ----------
        path: Path
            Path to location of pickle file which contains the video frames and annotations

        Returns
        -------
            Video annotations and video frames extracted from the pickle file
        """
        pickles = list(path.glob('*.pkl'))
        video_annotations = []
        video_frames = []
        for p in tqdm(pickles, desc='Loading Pickles'):
            with open(p, 'rb') as pkl:
                save_dict = pickle.load(pkl)
                video_annotations.append(save_dict['annotations'])
                video_frames.append(save_dict['frames'])
        return video_annotations, video_frames

    @staticmethod
    def load_annotations(path: Path):
        """
        Load the annotations from XML file
        Parameters
        ----------
        path: Path
            Path pointing to the XML file which contains the annotations
        Returns
        -------
            ndarray: The annotations loaded as an array
        """
        #
        root = ElementTree.parse(path).getroot()
        polygons = root[3].findall('polygon')
        buffer = np.empty((len(polygons), 4), np.int)
        for i, polygon in enumerate(polygons):
            pts = polygon.findall('pt')
            buffer[i][0] = int(pts[1][0].text)  # top-left x
            buffer[i][1] = int(pts[1][1].text)  # top-left y
            buffer[i][2] = int(pts[3][0].text)  # bottom-right x
            buffer[i][3] = int(pts[3][1].text)  # bottom-right y
        return buffer

    @staticmethod
    def load_frames(path: Path):
        """
        Load the frames from the ZIP file
        Parameters
        ----------
        path: Path
            Path pointing to the ZIP file which contains the video frames
        Returns
        -------
            ndarray: The video frames loaded as an array
        """
        zip_file = ZipFile(path)
        names = zip_file.namelist()
        buf = [cv2.imdecode(np.frombuffer(BytesIO(zip_file.open(name).read()).read(), np.uint8), 1) for name in names]
        return np.array(buf)

    @staticmethod
    def load_video(path: Path):
        """
        Load video frames from video file
        Parameters
        ----------
        path: Path
            Path pointing to the video file which has to be loaded
        Returns
        -------
            ndarray: The video frames loaded as an array
        """
        #
        cap = cv2.VideoCapture(str(path.resolve()))
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        buffer = np.empty((frame_count, frame_height, frame_width, 3), np.uint8)
        fc = 0
        ret = True
        while fc < frame_count and ret:
            ret, buffer[fc] = cap.read()
            fc += 1
        cap.release()
        return buffer

    @staticmethod
    def convert_channels(image: np.ndarray, channel_first: bool = False, channel_last: bool = False):
        """
        Converts image to channel first or channel last format

        Parameters
        ----------
        image: ndarray
            The image array
        channel_first: bool
            Set this to True to convert from channel last to channel first
        channel_last: bool
            Set this to True to convert from channel first to channel last

        Returns
        -------
            ndarray: Channel converted image
        """
        return np.moveaxis(image, -1, 0) if channel_first else (np.moveaxis(image, 0, -1) if channel_last else image)

    @staticmethod
    def get_bbox(preds: np.ndarray, target_size: int, scale: list, crop: list):
        """
        Recovers the bounding boxes after prediction

        Parameters
        ----------
        preds: ndarray
            The predicted bounding box coordinates
        target_size: int
            The image size
        scale: list
            The scale used when rescaling the bounding boxes
        crop: list
            The crop used when resizing the bounding boxes
        Returns
        -------
            ndarray: Recovered bounding box
        """
        preds = np.divide(preds, 10)
        preds = np.multiply(preds, target_size)
        preds = np.divide(preds, scale * 2)
        preds = np.add(preds, crop * 2)
        return preds.astype(int)

    def make_crops(self, previous_frame: np.ndarray,
                   current_frame: np.ndarray,
                   previous_annotation: np.ndarray,
                   current_annotation: np.ndarray,
                   validate: bool = False):
        """

        Parameters
        ----------
        previous_frame: ndarray
            The previous frame of the selected video frame
        current_frame: ndarray
            The current video frame
        previous_annotation: ndarray
            The annotation for the previous video frame
        current_annotation: ndarray
            The annotation for the current video frame
        validate: bool, default = False
            Will not perform image transformation if set to True
        Returns
        -------
            ndarray: The previous frame cropped based on annotation
            ndarray: The current frame cropped based on previous annotation
            ndarray: The current annotation rescaled
            list: The scale used to crop resize the frames
            list: The amount of crop performed on the frames
        """
        # Clip the bounding box to makes it lies inside the image
        previous_annotation[0] = np.clip(previous_annotation[0], 0, previous_frame.shape[0])
        previous_annotation[1] = np.clip(previous_annotation[1], 0, previous_frame.shape[1])
        previous_annotation[2] = np.clip(previous_annotation[2], 0, previous_frame.shape[0])
        previous_annotation[3] = np.clip(previous_annotation[3], 0, previous_frame.shape[1])

        center_x = int((previous_annotation[0] + previous_annotation[2]) / 2)
        center_y = int((previous_annotation[1] + previous_annotation[3]) / 2)
        width = abs(previous_annotation[2] - previous_annotation[0])
        height = abs(previous_annotation[3] - previous_annotation[1])

        # Create a crop window of size 7 x max(height, width) of the image
        crop_size = np.clip(7 * max(width, height), 10, 120)
        top = np.clip(center_x + crop_size, 0, previous_frame.shape[0])
        left = np.clip(center_y - crop_size, 0, previous_frame.shape[1])
        bottom = np.clip(center_x - crop_size, 0, previous_frame.shape[0])
        right = np.clip(center_y + crop_size, 0, previous_frame.shape[1])
        crop = [bottom, left]

        # Generate the cropped images
        previous_cropped = previous_frame[left: right, bottom: top, :]
        current_cropped = current_frame[left: right, bottom: top, :]

        # Calculate the scale needed to resize the image to :target_size:
        scale = np.divide(self.target_size, current_cropped.shape[:-1]).tolist()
        previous_cropped = cv2.resize(previous_cropped, (self.target_size, self.target_size))
        current_cropped = cv2.resize(current_cropped, (self.target_size, self.target_size))

        # Scale and crop the bounding box appropriately
        bbox = np.subtract(current_annotation, crop * 2)
        bbox = np.multiply(bbox, scale * 2)
        bbox = np.divide(bbox, self.target_size)

        # Apply transformations
        if not validate:
            if self.transforms:
                try:
                    x_min, y_max, x_max, y_min = bbox
                    bbox = np.array([x_min, y_min, x_max, y_max])
                    previous_augmented = {'image': previous_cropped}
                    current_augmented = {'image': current_cropped, 'bboxes': [bbox], 'category_id': [0]}
                    if np.random.random() > 0.5:
                        try:
                            previous_augmented = self.prev_hflip(**previous_augmented)
                            current_augmented = self.curr_hflip(**current_augmented)
                        except Exception as e:
                            print(e)
                    if np.random.random() > 0.5:
                        try:
                            previous_augmented = self.prev_ssr(**previous_augmented)
                            current_augmented = self.curr_ssr(**current_augmented)
                        except Exception as e:
                            print(e)
                    if np.random.random() > 0.5:
                        try:
                            previous_augmented = self.prev_bri(**previous_augmented)
                            current_augmented = self.curr_bri(**current_augmented)
                        except Exception as e:
                            print(e)
                    previous_cropped = previous_augmented['image']
                    current_cropped = current_augmented['image']
                    x_min, y_min, x_max, y_max = bbox
                    bbox = np.array([x_min, y_max, x_max, y_min])
                except Exception as e:
                    print(e)

        # Convert images from channel last to channel first format
        previous_cropped = self.convert_channels(previous_cropped, channel_first=True)
        current_cropped = self.convert_channels(current_cropped, channel_first=True)

        # Multiply the bounding box by 10
        bbox = np.multiply(bbox, 10)
        return previous_cropped, current_cropped, bbox, scale, crop


In [None]:


class GoTurnRemix(nn.Module):
    """
        Create a model based on GOTURN. The GOTURN architecture used a CaffeNet while GoTurnRemix uses AlexNet.
        The rest of the architecture is the similar to GOTURN. A PyTorch implementation of GOTURN can be found at:

        https://github.com/aakaashjois/PyTorch-GOTURN
    """

    def __init__(self):
        super(GoTurnRemix, self).__init__()
        # Load an AlexNet model pretrained on ImageNet
        self.features = nn.Sequential(*list(models.alexnet(pretrained=True).children())[:-1])
        # Freeze the pretrained layers
        for param in self.features.parameters():
            param.requires_grad = False
        self.regressor = nn.Sequential(
            nn.Linear(256 * 6 * 6 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4),
        )
        # Initialize the biases of the Linear layers to 1
        # Initialize weights to a normal distribution with 0 mean and 0.005 standard deviation
        for m in self.regressor.modules():
            if isinstance(m, nn.Linear):
                m.bias.data.fill_(1)
                m.weight.data.normal_(0, 0.005)

    def forward(self, previous, current):
        previous_features = self.features(previous)
        current_features = self.features(current)
        # Flatten, concatenate and pass to regressor the features
        return self.regressor(torch.cat((previous_features.view(previous_features.size(0), 256 * 6 * 6),
                                          current_features.view(current_features.size(0), 256 * 6 * 6)), 1))


In [None]:



class Experiment:
    """
        A helper class to facilitate the training and validation procedure of the GoTurnRemix model

        Parameters
        ----------
        learning_rate: float
            Learning rate to train the model. The optimizer is SGD and the loss is L1 Loss
        image_size: int
            The size of the input image. This has to be fixed before the data is created
        data_path: Path
            Path to the data folder. If the folder name includes "pickle", then the data saved as pickles are loaded
        augment: bool
            Perform augmentation on the images before training
        logs_path: Path
            Path to save the validation predictions at the end of each epoch
        models_path: Path
            Path to save the model state at the end of each epoch
        save_name: str
            Name of the folder in which the logs and models are saved. If not provided, the current datetime is used
    """

    def __init__(self,
                 learning_rate: float,
                 image_size: int,
                 data_path: Path,
                 augment: bool = True,
                 logs_path: Path = None,
                 models_path: Path = None,
                 save_name: str = None,
                 comet_api: str = None):
        self.image_size = image_size
        self.logs_path = logs_path
        self.models_path = models_path
        self.model = GoTurnRemix()
#         self.model.cuda()
        self.criterion = torch.nn.L1Loss()
        self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=learning_rate)
        self.model_name = str(datetime.datetime.now()).split('.')[0].replace(':', '-').replace(' ', '-')
        self.model_name = save_name if save_name else self.model_name
        self.augment = augment
        self.data = Data(data_path, target_size=self.image_size, transforms=augment)
        self.comet = None
        if comet_api:
            self.comet = Comet(api_key=comet_api)
            self.comet.log_parameter('learning_rate', learning_rate)
            self.comet.log_parameter('image_size', image_size)
            self.comet.log_parameter('augment', augment)

    def __train_step__(self, data):
        """
        Performs one step of the training procedure

        Parameters
        ----------
        data
            data obtained from @Data.__getitem__

        Returns
        -------
           Loss at the end of training step
        """
        if self.comet:
            self.comet.train()
        previous_cropped, current_cropped, bbox, scale, crop = data
        previous_cropped = torch.div(previous_cropped, 255).float()
        current_cropped = torch.div(current_cropped, 255).float()
        previous_cropped = torch.autograd.Variable(previous_cropped, requires_grad=True)
        current_cropped = torch.autograd.Variable(current_cropped, requires_grad=True)
        bbox = bbox.requires_grad_(True).float()
        self.optimizer.zero_grad()
        preds = self.model(previous_cropped, current_cropped)

        del previous_cropped
        del current_cropped
        gc.collect()

        loss = self.criterion(preds, bbox)
        if self.comet:
            self.comet.log_metric('loss', loss)
        loss.backward()
        self.optimizer.step()
        return loss

    def __test__(self):
        """
        Test tracking of the model

        Returns
        -------
            Test loss and test predictions
        """
        # Set model to evaluation mode
        if self.comet:
            self.comet.test()
        self.model.eval()
        test_preds = []
        test_loss = []
        video_frames = self.data.video_frames[-1]
        video_annotations = self.data.video_annotations[-1]
        p_a = video_annotations[0]
        p_f = video_frames[0]
        test_preds.append(p_a)

        for i in tqdm(range(1, len(video_annotations)), desc='Validating'):
            c_a = video_annotations[i]
            c_f = video_frames[i]
            p_c, c_c, bbox, scale, crop = self.data.make_crops(p_f, c_f, p_a, c_a)
            p_c = torch.div(torch.from_numpy(p_c), 255).unsqueeze(0).float()
            c_c = torch.div(torch.from_numpy(c_c), 255).unsqueeze(0).float()
            bbox = torch.tensor(bbox, requires_grad=False).float()
            preds = self.model(p_c, c_c)

            del p_c
            del c_c
            gc.collect()

            loss = torch.nn.functional.l1_loss(preds, bbox)
            if self.comet:
                self.comet.log_metric('val_loss', loss)
            test_loss.append(loss.item())
            preds = self.data.get_bbox(preds.cpu().detach().numpy()[0], self.image_size, scale, crop)
            test_preds.append(preds)
            p_a = preds
            p_f = c_f
        return test_loss, test_preds

    def __validate__(self):
        """
        Performs validation on the model

        Returns
        -------
            Validation loss and validation predictions
        """
        # Set model to evaluation mode
        if self.comet:
            self.comet.validate()
        self.model.eval()
        validation_preds = []
        validation_loss = []
        video_frames = self.data.video_frames[-1]
        video_annotations = self.data.video_annotations[-1]
        p_a = video_annotations[0]
        p_f = video_frames[0]
        validation_preds.append(p_a)

        for i in tqdm(range(1, len(video_annotations)), desc='Validating'):
            c_a = video_annotations[i]
            c_f = video_frames[i]
            p_c, c_c, bbox, scale, crop = self.data.make_crops(p_f, c_f, p_a, c_a)
            p_c = torch.div(torch.from_numpy(p_c), 255).unsqueeze(0).float()
            c_c = torch.div(torch.from_numpy(c_c), 255).unsqueeze(0).float()
            bbox = torch.tensor(bbox, requires_grad=False).float()
            preds = self.model(p_c, c_c)

            del p_c
            del c_c
            gc.collect()

            loss = torch.nn.functional.l1_loss(preds, bbox)
            if self.comet:
                self.comet.log_metric('val_loss', loss)
            validation_loss.append(loss.item())
            preds = self.data.get_bbox(preds.cpu().detach().numpy()[0], self.image_size, scale, crop)
            validation_preds.append(preds)
            p_a = c_a
            p_f = c_f
        return validation_loss, validation_preds

    def train(self, epochs: int, batch_size: int, validate: bool = True, test: bool = True):
        """
        Trains the model for @epochs number of epochs

        Parameters
        ----------
        epochs: int
            Number of epochs to train the model
        batch_size: int
            The size of each batch when training the model
        validate: bool, default=True
            If True, validation occurs at the end of each epoch
            The results are saved in @logs_path and models are saved in @models_path
        test: bool, default=True
            If True, the model is tested for tracking at the end of the training procedure
            The results are saved in @logs_path

        Returns
        -------
            list: List containing the training loss at the end of each epoch
        """
        if self.comet:
            self.comet.log_parameter('epochs', epochs)
            self.comet.log_parameter('batch_size', batch_size)
        loss_per_epoch = []
        preds_per_epoch = []
        # Set the model to training mode
        self.model.train()
        # Create a DataLoader to feed data to the model
        dataloader = torch.utils.data.DataLoader(dataset=self.data, batch_size=batch_size, shuffle=True)

        # Run for @epochs number of epochs
        for epoch in range(epochs):
            if self.comet:
                self.comet.log_metric('epoch', epoch)
            running_loss = []
            for step, data in enumerate(tqdm(dataloader,
                                             total=int(len(self.data) / batch_size),
                                             desc='Epoch {}'.format(epoch))):
                loss = self.__train_step__(data)
                running_loss.append(loss.item())
            training_loss = sum(running_loss) / len(running_loss)
            if self.comet:
                self.comet.log_metric('mean_train_loss', training_loss)
            loss_per_epoch.append(sum(running_loss) / len(running_loss))
            if validate:
                validation_loss, validation_preds = self.__validate__()
                if self.comet:
                    self.comet.log_metric('mean_validation_loss', validation_loss)
                preds_per_epoch.append(validation_preds)
                print('Validation loss: {}'.format(sum(validation_loss) / len(validation_loss)))
            # Save the model at this stage
            if self.models_path:
                (self.models_path / self.model_name).mkdir(exist_ok=True)
                torch.save(self.model, (self.models_path / self.model_name / 'epoch_{}'.format(epoch)).resolve())
            print('Training Loss: {}'.format(training_loss))
        # Save the validation frames, ground truths and predictions at this stage
        if self.logs_path:
            (self.logs_path / self.model_name).mkdir(exist_ok=True)
            save = {'frames': self.data.video_frames[-1],
                    'truth': self.data.video_annotations[-1],
                    'preds': preds_per_epoch}
            np.save(str((self.logs_path / self.model_name / 'preds_per_epoch.npy').resolve()), save)
        # Test the model and save the results
        if test:
            test_loss, test_preds = self.__test__()
            if self.logs_path:
                (self.logs_path / self.model_name).mkdir(exist_ok=True)
                save = {'frames': self.data.video_frames[-1],
                        'truth': self.data.video_annotations[-1],
                        'preds': test_preds,
                        'loss': test_loss}
                np.save(str((self.logs_path / self.model_name / 'test_preds.npy').resolve()), save)
        return loss_per_epoch


In [None]:


# Set the path to load data
data_path = Path('../plant-data')
# Set the path to save predictions
logs_path = Path('/logs')
# Set the path to save models
models_path = Path('../models')

# Modify the hyperparamters as needed
learning_rate = 1e-5
epochs = 50
batch_size = 128
image_size = 225
comet_api = 'BznUqlKMJs3tCDblJ9RnB7TiC'

# Create an experiment object
goturn_exp = Experiment(learning_rate=learning_rate,
                        image_size=image_size,
                        data_path=data_path,
                        augment=False,
                        logs_path=logs_path,
                        models_path=models_path,
                        comet_api=comet_api)

# Train the model
losses = goturn_exp.train(epochs, batch_size, validate=True, test=True)


In [2]:


path_to_validation_results = Path('../models/2022-05-20-15-02-19/epoch_49') 
preds_per_epoch = np.load(path_to_validation_results.resolve()).reshape(1, )[0]
frames = preds_per_epoch['frames']
truth = preds_per_epoch['truth']
preds = preds_per_epoch['preds']

use_centers = False

imgs = []
for f, t, p in zip(frames, truth, preds[-1]):
    img = f.copy()
    img = cv2.rectangle(img, (t[0], t[1]), (t[2], t[3]), (255, 255, 255), 2)
    img = cv2.rectangle(img, (p[0], p[1]), (p[2], p[3]), (255, 0, 0), 2)
    img = cv2.rectangle(img, (p[0], p[1]), (p[2], p[3]), (255, 0, 0), 2)
    imgs.append(img)
    cv2.imshow('Tracking', img)
    if cv2.waitKey(25) & 0xFF == ord('q'):
        break
cv2.destroyAllWindows()

height, width, channels = imgs[0].shape
video = cv2.VideoWriter('video.avi', cv2.VideoWriter_fourcc(*"MJPG"), 60, (width, height))
for img in imgs:
    video.write(img)
video.release()


AttributeError: 'NpzFile' object has no attribute 'reshape'