In [1]:
import sys
sys.path.append("..")
from general import *

In [2]:
from image_feature_extraction import ImageFeatureExtractor
from train_frame_prediction import FramePredictor_Trainer

In [3]:
# Feature-extraction wrapper with GPU use disabled; its `feature_extractor`
# is replaced by the autoencoder loaded a few cells below.
feat_ext = ImageFeatureExtractor(useGPU=False)

In [4]:
from AutoEncoders.C2D_Models import *

In [5]:
# Pretrained single-channel frame autoencoder; it becomes the network used by
# `extract_features` below.
model = C2D_AE_128_3x3(channels=1)
load_model(model, "../AutoEncoders/C2D_AE_models/C2D_AE_128_3x3_UCSD2/C2D_AE_128_3x3_UCSD2.pth.tar")
# The network is only used as a frozen feature extractor. eval() makes its
# BatchNorm layers use running statistics, so features do not depend on how
# frames are batched and the running stats are not overwritten.
# eval() returns the module itself, so the cell still displays the model repr.
model.cpu().eval()

C2D_AE_128_3x3(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (4): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), st

In [6]:
# Route every feature extraction through the autoencoder loaded above.
feat_ext.feature_extractor = model

In [7]:
def extract_features(images, extractor = None):
    """Encode a batch of frames and flatten each encoding.

    Args:
        images: tensor batch fed directly to the network after being moved to
            ``extractor.device``.
        extractor: object exposing ``feature_extractor`` (an ``nn.Module``)
            and ``device``. Defaults to the notebook-global ``feat_ext``, so
            existing one-argument calls behave as before.

    Returns:
        CPU tensor holding the last element of the network's output,
        flattened from dim 1 onward. It does not require grad.
    """
    if extractor is None:
        extractor = feat_ext
    network = extractor.feature_extractor
    # Frozen feature extractor: in train mode, BatchNorm would normalise with
    # per-batch statistics (features would depend on batch composition) and
    # would overwrite its running statistics on every call.
    network.eval()
    with torch.no_grad():
        # NOTE(review): assumes the network returns a sequence whose last item
        # is the encoding. If it returned a plain tensor, [-1] would select the
        # last sample instead; confirm against C2D_AE_128_3x3.forward.
        encoding = network(images.to(extractor.device))[-1]
    return encoding.cpu().flatten(start_dim = 1, end_dim = -1)

# Train: fit the LSTM frame-feature predictor on autoencoder encodings of the UCSD training clips

In [8]:
# UCSD subset 2 training clips: 16-frame stacks sampled with stride 2,
# image_size=128.
ucsd_train = UCSD(
    2,
    asImages=False,
    image_size=128,
    n_frames=16,
    sample_stride=2,
)

16it [00:47,  2.95s/it]


In [None]:
# Encode every training clip once, up front (dims 0 and 1 are swapped before
# encoding), then split the (features, label) pairs into train/val loaders.
ucsd_processed_train = [
    (extract_features(clip.transpose(0, 1)), clip_label)
    for clip, clip_label in ucsd_train
]
train_loader, val_loader = get_data_loaders(ucsd_processed_train, batch_size=32)

In [None]:
# Drop a trainer left over from a previous run of this cell before building a
# new one. On a fresh kernel the name does not exist yet, which is the only
# error worth ignoring here.
try:
    del trainer
except NameError:
    pass
trainer = FramePredictor_Trainer(256, 256, useGPU=False)

In [None]:
from frame_prediction import FrameFeaturePredictor

In [None]:
# Replace the trainer's model with a fresh feature-space predictor.
# NOTE(review): the two positional flags are presumably isTrain=True and
# useGPU=False (cf. the keyword call in the Test section); confirm against
# FrameFeaturePredictor's signature.
trainer.model = FrameFeaturePredictor(256, 256, True, False)

In [None]:
# Train for 200 epochs at lr=1e-4; the checkpoint path is the one loaded again
# in the Test section.
trainer.train(
    "C2D_LSTM_models/AE_LSTM_UCSD2.tar.pth",
    train_loader,
    val_loader,
    learning_rate=1e-4,
    epochs=200,
)

In [None]:
# Marks the end of the training section in the cell output.
print("DONE", end="\n")

# Test: score the trained predictor on the UCSD test set with per-directory ROC-AUC

In [None]:
# Test split of the same UCSD subset, every frame kept (sample_stride=1).
ucsd_test = UCSD(
    2,
    isTrain=False,
    image_size=128,
    sample_stride=1,
)

In [None]:
from frame_prediction import FrameFeaturePredictor

In [None]:
# Fresh predictor in test mode, filled from the checkpoint written by
# trainer.train above.
test_model = FrameFeaturePredictor(
    256,
    256,
    isTrain=False,
    useGPU=False,
)
load_model(test_model, "C2D_LSTM_models/AE_LSTM_UCSD2.tar.pth")

In [None]:
# Drop a tester left over from a previous run of this cell. On a fresh kernel
# the name does not exist yet, which is the only error worth ignoring here.
try:
    del tester
except NameError:
    pass
tester = FramePredictor_Trainer(isTrain = False, useGPU = False)

In [13]:
def test_conv_features_lstm(self,
                            model,
                            feat_ext,
                            test_data,
                            batch_size = 8,
                            stackFrames = 16,
                            input_steps = 8,
                            save_as = False):
    """Evaluate a feature-space frame predictor with per-directory ROC-AUC.

    NOTE(review): the next cell defines a function with the same name, so once
    both cells have run, this version (which encodes frames through
    ``feat_ext.extract_features``) is shadowed. Consider keeping only one.

    Each directory's frames are encoded in chunks of ``batch_size`` and cut
    into non-overlapping windows of ``stackFrames``. Trailing frames that do
    not fill a window are skipped. For each window the model receives the
    first ``input_steps`` encodings and unrolls ``stackFrames - input_steps``
    further steps. ``self.loss_criterion(inputs[1:], outputs[:-1])`` is turned
    into regularity scores, which are compared with ``labels[1:]`` via ROC-AUC.

    Args:
        self: trainer-like object providing ``device`` and ``loss_criterion``.
            ``self.results`` is overwritten with the collected results.
        model: predictor exposing ``unroll(inputs, future_steps=...)``.
        feat_ext: object exposing ``extract_features(batch)``.
        test_data: iterable of ``(directory_inputs, directory_labels)`` pairs.
        batch_size: number of frames encoded per call.
        stackFrames: window length in frames.
        input_steps: number of leading window frames fed to the model.
        save_as: optional path; when truthy, ``self.results`` is pickled there.

    Returns:
        Mean of the per-directory ROC-AUC scores.
    """

    @torch.no_grad()
    def _score_directory(directory_inputs, directory_labels):
        # Inference only. Without no_grad every loss keeps its autograd graph
        # alive for the whole directory, and NumPy cannot convert tensors that
        # require grad (np.array on the losses below would raise).
        encoded_frames = list()
        for idx in range(0, len(directory_inputs), batch_size):
            # Extending a list with a tensor appends its rows one by one.
            encoded_frames += feat_ext.extract_features(torch.stack(directory_inputs[idx: (idx + batch_size)]))
        encoded_frames = torch.stack(encoded_frames)

        directory_targets, directory_loss = list(), list()
        usable = (len(encoded_frames) // stackFrames) * stackFrames
        for start_idx in range(0, usable, stackFrames):
            # stackFrames consecutive per-frame features; unsqueeze adds dim 1.
            test_inputs = encoded_frames[start_idx : (start_idx + stackFrames)]
            test_inputs = test_inputs.unsqueeze(dim = 1).to(self.device)
            test_labels = directory_labels[start_idx : (start_idx + stackFrames)]
            outputs = model.unroll(test_inputs[:input_steps], future_steps = (stackFrames - input_steps))
            # outputs[k] is compared against frame k + 1 of the window.
            loss = self.loss_criterion(test_inputs[1:], outputs[:-1])
            directory_loss += loss
            directory_targets += test_labels[1:]
        return directory_targets, directory_loss

    def _as_array(rows):
        # Directories generally differ in length. NumPy >= 1.24 refuses to
        # build a ragged array implicitly, so fall back to an object array.
        try:
            return np.array(rows)
        except ValueError:
            return np.array(rows, dtype = object)

    model.to(self.device)
    overall_targets, overall_losses = list(), list()
    overall_roc_auc, overall_regularity_scores = list(), list()
    for directory_inputs, directory_labels in tqdm(test_data):
        directory_targets, directory_loss = _score_directory(directory_inputs, directory_labels)

        regularity_scores = loss_to_regularity(directory_loss)
        try:
            directory_roc_auc = roc_auc_score(directory_targets, regularity_scores)
        except ValueError:
            # roc_auc_score raises ValueError e.g. when a directory's labels
            # contain a single class. Such directories are scored as 1.0,
            # which can inflate the mean.
            directory_roc_auc = 1.0
        overall_roc_auc.append(directory_roc_auc)
        overall_regularity_scores.append(regularity_scores)
        overall_targets.append(directory_targets)
        overall_losses.append(directory_loss)

    mean_roc_auc = np.mean(overall_roc_auc)

    self.results = {
        "targets": _as_array(overall_targets),
        "losses": _as_array(overall_losses),
        "regularity": overall_regularity_scores,
        "AUC_ROC_score": overall_roc_auc,
        "final_AUC_ROC":mean_roc_auc,
    }

    if save_as:
        with open(save_as, "wb") as f:
            pkl.dump(self.results, f)

    return mean_roc_auc

In [None]:
def test_conv_features_lstm(self,
                            model,
                            feat_ext,
                            test_data,
                            batch_size = 8,
                            stackFrames = 16,
                            input_steps = 8,
                            save_as = False):
    """Evaluate a feature-space frame predictor with per-directory ROC-AUC.

    Each directory's frames are encoded in chunks of ``batch_size`` with
    ``feat_ext.feature_extractor`` (last output element, flattened from dim 1).
    The encodings are cut into non-overlapping windows of ``stackFrames``, and
    trailing frames that do not fill a window are skipped. For each window the
    model receives the first ``input_steps`` encodings and unrolls
    ``stackFrames - input_steps`` further steps.
    ``self.loss_criterion(inputs[1:], outputs[:-1])`` is turned into
    regularity scores, which are compared with ``labels[1:]`` via ROC-AUC.

    Args:
        self: trainer-like object providing ``device`` and ``loss_criterion``.
            ``self.results`` is overwritten with the collected results.
        model: predictor exposing ``unroll(inputs, future_steps=...)``.
        feat_ext: object exposing ``feature_extractor`` (callable) and
            ``device``.
        test_data: iterable of ``(directory_inputs, directory_labels)`` pairs.
        batch_size: number of frames encoded per forward pass.
        stackFrames: window length in frames.
        input_steps: number of leading window frames fed to the model.
        save_as: optional path; when truthy, ``self.results`` is pickled there.

    Returns:
        Mean of the per-directory ROC-AUC scores.
    """

    def _encode(frames):
        # Same computation as the notebook's global ``extract_features``, but
        # bound to the ``feat_ext`` argument rather than the global of that name.
        encoded = feat_ext.feature_extractor(frames.to(feat_ext.device))[-1]
        return encoded.detach().cpu().flatten(start_dim = 1, end_dim = -1)

    @torch.no_grad()
    def _score_directory(directory_inputs, directory_labels):
        # Inference only. Without no_grad every loss keeps its autograd graph
        # alive for the whole directory, and NumPy cannot convert tensors that
        # require grad (np.array on the losses below would raise).
        encoded_frames = list()
        for idx in range(0, len(directory_inputs), batch_size):
            # Extending a list with a tensor appends its rows one by one.
            encoded_frames += _encode(torch.stack(directory_inputs[idx: (idx + batch_size)]))
        encoded_frames = torch.stack(encoded_frames)

        directory_targets, directory_loss = list(), list()
        usable = (len(encoded_frames) // stackFrames) * stackFrames
        for start_idx in range(0, usable, stackFrames):
            # (stackFrames, feature_dim) -> (stackFrames, 1, feature_dim)
            test_inputs = encoded_frames[start_idx : (start_idx + stackFrames)]
            test_inputs = test_inputs.unsqueeze(dim = 1).to(self.device)
            test_labels = directory_labels[start_idx : (start_idx + stackFrames)]
            outputs = model.unroll(test_inputs[:input_steps], future_steps = (stackFrames - input_steps))
            # outputs[k] is compared against frame k + 1 of the window.
            loss = self.loss_criterion(test_inputs[1:], outputs[:-1])
            directory_loss += loss
            directory_targets += test_labels[1:]
        return directory_targets, directory_loss

    def _as_array(rows):
        # Directories generally differ in length. NumPy >= 1.24 refuses to
        # build a ragged array implicitly, so fall back to an object array.
        try:
            return np.array(rows)
        except ValueError:
            return np.array(rows, dtype = object)

    model.to(self.device)
    # NOTE(review): the model is not switched to eval mode here; presumably
    # FrameFeaturePredictor(isTrain=False) takes care of that. Confirm.
    overall_targets, overall_losses = list(), list()
    overall_roc_auc, overall_regularity_scores = list(), list()
    for directory_inputs, directory_labels in tqdm(test_data):
        directory_targets, directory_loss = _score_directory(directory_inputs, directory_labels)

        regularity_scores = loss_to_regularity(directory_loss)
        try:
            directory_roc_auc = roc_auc_score(directory_targets, regularity_scores)
        except ValueError:
            # roc_auc_score raises ValueError e.g. when a directory's labels
            # contain a single class. Such directories are scored as 1.0,
            # which can inflate the mean.
            directory_roc_auc = 1.0
        overall_roc_auc.append(directory_roc_auc)
        overall_regularity_scores.append(regularity_scores)
        overall_targets.append(directory_targets)
        overall_losses.append(directory_loss)

    mean_roc_auc = np.mean(overall_roc_auc)

    self.results = {
        "targets": _as_array(overall_targets),
        "losses": _as_array(overall_losses),
        "regularity": overall_regularity_scores,
        "AUC_ROC_score": overall_roc_auc,
        "final_AUC_ROC":mean_roc_auc,
    }

    if save_as:
        with open(save_as, "wb") as f:
            pkl.dump(self.results, f)

    return mean_roc_auc

In [None]:
# Run the evaluation; the returned mean ROC-AUC is shown as the cell output.
test_conv_features_lstm(
    tester, test_model, feat_ext, ucsd_test,
    batch_size=4, stackFrames=16, input_steps=8, save_as=False,
)

In [None]:
# Marks the end of the test section in the cell output.
print("Testing done", end="\n")