This repository has been archived by the owner on Nov 23, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
MFnet based 3d-conv heatmaps | ||
tested for pytorch version 0.4 | ||
""" | ||
import os | ||
import cv2 | ||
import torch | ||
import argparse | ||
import numpy as np | ||
from mfnet_3d import MFNET_3D | ||
from scipy.ndimage import zoom | ||
|
||
def center_crop(data, tw=224, th=224): | ||
h, w, c = data.shape | ||
x1 = int(round((w - tw) / 2.)) | ||
y1 = int(round((h - th) / 2.)) | ||
cropped_data = data[y1:(y1+th), x1:(x1+tw), :] | ||
return cropped_data | ||
|
||
def load_images(frame_dir, selected_frames): | ||
images = np.zeros((16, 224, 224, 3)) | ||
orig_imgs = np.zeros_like(images) | ||
for i, frame_name in enumerate(selected_frames): | ||
im_name = os.path.join(frame_dir, frame_name) | ||
next_image = cv2.imread(im_name, cv2.IMREAD_COLOR) | ||
scaled_img = cv2.resize(next_image, (256, 256), interpolation=cv2.INTER_LINEAR) # resize to 256x256 | ||
cropped_img = center_crop(scaled_img) # center crop 224x224 | ||
final_img = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB) | ||
images[i] = final_img | ||
orig_imgs[i] = cropped_img | ||
|
||
torch_imgs = torch.from_numpy(images.transpose(3,0,1,2)) | ||
torch_imgs = torch_imgs.float() / 255.0 | ||
mean_3d = [124 / 255, 117 / 255, 104 / 255] | ||
std_3d = [0.229, 0.224, 0.225] | ||
for t, m, s in zip(torch_imgs, mean_3d, std_3d): | ||
t.sub_(m).div_(s) | ||
return np.expand_dims(orig_imgs, 0), torch_imgs.unsqueeze(0) | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='mfnet-base-parser') | ||
parser.add_argument("num_classes", type=int) | ||
parser.add_argument("model_weights", type=str) | ||
parser.add_argument("frame_dir", type=str) | ||
parser.add_argument("label", type=int) | ||
parser.add_argument("--base_output_dir", type=str, default=r"visualisations") | ||
return parser.parse_args() | ||
|
||
args = parse_args() | ||
|
||
frame_names = os.listdir(args.frame_dir) | ||
frame_indices = list(np.linspace(0, len(frame_names)-1, num=16, dtype=np.int)) | ||
selected_frames = [frame_names[i] for i in frame_indices] | ||
|
||
RGB_vid, vid = load_images(args.frame_dir, selected_frames) | ||
|
||
# load network structure, load weights, send to gpu, set to evaluation mode | ||
model_ft = MFNET_3D(args.num_classes) | ||
model_ft = torch.nn.DataParallel(model_ft).cuda() | ||
checkpoint = torch.load(args.model_weights, map_location={'cuda:1':'cuda:0'}) | ||
model_ft.load_state_dict(checkpoint['state_dict']) | ||
model_ft.eval() | ||
|
||
# get predictions, last convolution output and the weights of the prediction layer | ||
predictions, layerout = model_ft(torch.tensor(vid).cuda()) | ||
layerout = torch.tensor(layerout[0].numpy().transpose(1, 2, 3, 0)) | ||
pred_weights = model_ft.module.classifier.weight.data.detach().cpu().numpy().transpose() | ||
|
||
pred = torch.argmax(predictions).item() | ||
|
||
cam = np.zeros(dtype = np.float32, shape = layerout.shape[0:3]) | ||
for i, w in enumerate(pred_weights[:, args.label]): | ||
|
||
# Compute cam for every kernel | ||
cam += w * layerout[:, :, :, i] | ||
|
||
# Resize CAM to frame level | ||
cam = zoom(cam, (2, 32, 32)) # output map is 8x7x7, so multiply to get to 16x224x224 (original image size) | ||
|
||
# normalize | ||
cam -= np.min(cam) | ||
cam /= np.max(cam) - np.min(cam) | ||
|
||
# make dirs and filenames | ||
example_name = os.path.basename(args.frame_dir) | ||
heatmap_dir = os.path.join(args.base_output_dir, example_name, str(args.label), "heatmap") | ||
focusmap_dir = os.path.join(args.base_output_dir, example_name, str(args.label), "focusmap") | ||
for d in [heatmap_dir, focusmap_dir]: | ||
if not os.path.exists(d): | ||
os.makedirs(d) | ||
|
||
file = open(os.path.join(args.base_output_dir, example_name, str(args.label), "info.txt"),"a") | ||
file.write("Visualizing for class {}\n".format(args.label)) | ||
file.write("Predicted class {}\n".format(pred)) | ||
file.close() | ||
|
||
# produce heatmap and focusmap for every frame and activation map | ||
for i in range(0, cam.shape[0]): | ||
# Create colourmap | ||
heatmap = cv2.applyColorMap(np.uint8(255*cam[i]), cv2.COLORMAP_JET) | ||
# Create focus map | ||
focusmap = np.uint8(255*cam[i]) | ||
focusmap = cv2.normalize(cam[i], dst=focusmap, alpha=20, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1) | ||
|
||
# Create frame with heatmap | ||
heatframe = heatmap//2 + RGB_vid[0][i]//2 | ||
cv2.imwrite(os.path.join(heatmap_dir,'{:03d}.png'.format(i)), heatframe) | ||
|
||
# Create frame with focus map in the alpha channel | ||
focusframe = RGB_vid[0][i] | ||
focusframe = cv2.cvtColor(np.uint8(focusframe), cv2.COLOR_BGR2BGRA) | ||
focusframe[:,:,3] = focusmap | ||
cv2.imwrite(os.path.join(focusmap_dir,'{:03d}.png'.format(i)), focusframe) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Original Author: Yunpeng Chen | ||
https://github.com/cypw/PyTorch-MFNet/blob/master/network/mfnet_3d.py | ||
""" | ||
|
||
from collections import OrderedDict | ||
import torch.nn as nn | ||
|
||
class BN_AC_CONV3D(nn.Module): | ||
|
||
def __init__(self, num_in, num_filter, | ||
kernel=(1,1,1), pad=(0,0,0), stride=(1,1,1), g=1, bias=False): | ||
super(BN_AC_CONV3D, self).__init__() | ||
self.bn = nn.BatchNorm3d(num_in) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.conv = nn.Conv3d(num_in, num_filter, kernel_size=kernel, padding=pad, | ||
stride=stride, groups=g, bias=bias) | ||
|
||
def forward(self, x): | ||
h = self.relu(self.bn(x)) | ||
h = self.conv(h) | ||
return h | ||
|
||
|
||
class MF_UNIT(nn.Module): | ||
|
||
def __init__(self, num_in, num_mid, num_out, g=1, stride=(1,1,1), first_block=False, use_3d=True): | ||
super(MF_UNIT, self).__init__() | ||
num_ix = int(num_mid/4) | ||
kt,pt = (3,1) if use_3d else (1,0) | ||
# prepare input | ||
self.conv_i1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_ix, kernel=(1,1,1), pad=(0,0,0)) | ||
self.conv_i2 = BN_AC_CONV3D(num_in=num_ix, num_filter=num_in, kernel=(1,1,1), pad=(0,0,0)) | ||
# main part | ||
self.conv_m1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_mid, kernel=(kt,3,3), pad=(pt,1,1), stride=stride, g=g) | ||
if first_block: | ||
self.conv_m2 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0)) | ||
else: | ||
self.conv_m2 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_out, kernel=(1,3,3), pad=(0,1,1), g=g) | ||
# adapter | ||
if first_block: | ||
self.conv_w1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0), stride=stride) | ||
|
||
def forward(self, x): | ||
|
||
h = self.conv_i1(x) | ||
x_in = x + self.conv_i2(h) | ||
|
||
h = self.conv_m1(x_in) | ||
h = self.conv_m2(h) | ||
|
||
if hasattr(self, 'conv_w1'): | ||
x = self.conv_w1(x) | ||
|
||
return h + x | ||
|
||
|
||
class MFNET_3D(nn.Module): | ||
|
||
def __init__(self, num_classes, dropout=None, pretrained=False, pretrained_model="", **kwargs): | ||
super(MFNET_3D, self).__init__() | ||
|
||
groups = 16 | ||
k_sec = { 2: 3, \ | ||
3: 4, \ | ||
4: 6, \ | ||
5: 3 } | ||
|
||
# conv1 - x224 (x16) | ||
conv1_num_out = 16 | ||
self.conv1 = nn.Sequential(OrderedDict([ | ||
('conv', nn.Conv3d( 3, conv1_num_out, kernel_size=(3,5,5), padding=(1,2,2), stride=(1,2,2), bias=False)), | ||
('bn', nn.BatchNorm3d(conv1_num_out)), | ||
('relu', nn.ReLU(inplace=True)) | ||
])) | ||
self.maxpool = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) | ||
|
||
# conv2 - x56 (x8) | ||
num_mid = 96 | ||
conv2_num_out = 96 | ||
self.conv2 = nn.Sequential(OrderedDict([ | ||
("B%02d"%i, MF_UNIT(num_in=conv1_num_out if i==1 else conv2_num_out, | ||
num_mid=num_mid, | ||
num_out=conv2_num_out, | ||
stride=(2,1,1) if i==1 else (1,1,1), | ||
g=groups, | ||
first_block=(i==1))) for i in range(1,k_sec[2]+1) | ||
])) | ||
|
||
# conv3 - x28 (x8) | ||
num_mid *= 2 | ||
conv3_num_out = 2 * conv2_num_out | ||
self.conv3 = nn.Sequential(OrderedDict([ | ||
("B%02d"%i, MF_UNIT(num_in=conv2_num_out if i==1 else conv3_num_out, | ||
num_mid=num_mid, | ||
num_out=conv3_num_out, | ||
stride=(1,2,2) if i==1 else (1,1,1), | ||
g=groups, | ||
first_block=(i==1))) for i in range(1,k_sec[3]+1) | ||
])) | ||
|
||
# conv4 - x14 (x8) | ||
num_mid *= 2 | ||
conv4_num_out = 2 * conv3_num_out | ||
self.conv4 = nn.Sequential(OrderedDict([ | ||
("B%02d"%i, MF_UNIT(num_in=conv3_num_out if i==1 else conv4_num_out, | ||
num_mid=num_mid, | ||
num_out=conv4_num_out, | ||
stride=(1,2,2) if i==1 else (1,1,1), | ||
g=groups, | ||
first_block=(i==1))) for i in range(1,k_sec[4]+1) | ||
])) | ||
|
||
# conv5 - x7 (x8) | ||
num_mid *= 2 | ||
conv5_num_out = 2 * conv4_num_out | ||
self.conv5 = nn.Sequential(OrderedDict([ | ||
("B%02d"%i, MF_UNIT(num_in=conv4_num_out if i==1 else conv5_num_out, | ||
num_mid=num_mid, | ||
num_out=conv5_num_out, | ||
stride=(1,2,2) if i==1 else (1,1,1), | ||
g=groups, | ||
first_block=(i==1))) for i in range(1,k_sec[5]+1) | ||
])) | ||
|
||
# final | ||
self.tail = nn.Sequential(OrderedDict([ | ||
('bn', nn.BatchNorm3d(conv5_num_out)), | ||
('relu', nn.ReLU(inplace=True)) | ||
])) | ||
|
||
if dropout: | ||
self.globalpool = nn.Sequential(OrderedDict([ | ||
('avg', nn.AvgPool3d(kernel_size=(8,7,7), stride=(1,1,1))), | ||
('dropout', nn.Dropout(p=dropout)), | ||
])) | ||
else: | ||
self.globalpool = nn.Sequential(OrderedDict([ | ||
('avg', nn.AvgPool3d(kernel_size=(8,7,7), stride=(1,1,1))), | ||
# ('dropout', nn.Dropout(p=0.5)), only for fine-tuning | ||
])) | ||
self.classifier = nn.Linear(conv5_num_out, num_classes) | ||
|
||
def forward(self, x): | ||
assert x.shape[2] == 16 | ||
|
||
h = self.conv1(x) # x224 -> x112 | ||
h = self.maxpool(h) # x112 -> x56 | ||
|
||
h = self.conv2(h) # x56 -> x56 | ||
h = self.conv3(h) # x56 -> x28 | ||
h = self.conv4(h) # x28 -> x14 | ||
h = self.conv5(h) # x14 -> x7 | ||
h = self.tail(h) | ||
layerout = h.detach().cpu() | ||
h = self.globalpool(h) | ||
|
||
h = h.view(h.shape[0], -1) | ||
h = self.classifier(h) | ||
|
||
return h, layerout |