This repository has been archived by the owner on Nov 23, 2023. It is now read-only.

Commit 9f68860

added mfnet pytorch code
georkap committed Jan 28, 2019
1 parent 7819352 commit 9f68860
Showing 2 changed files with 278 additions and 0 deletions.
116 changes: 116 additions & 0 deletions heat_tubes_mfnet_pytorch.py
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
"""
MFNet-based 3D-conv class activation heatmaps.
Tested with PyTorch 0.4.
"""
import os
import cv2
import torch
import argparse
import numpy as np
from mfnet_3d import MFNET_3D
from scipy.ndimage import zoom

def center_crop(data, tw=224, th=224):
    h, w, c = data.shape
    x1 = int(round((w - tw) / 2.))
    y1 = int(round((h - th) / 2.))
    cropped_data = data[y1:(y1 + th), x1:(x1 + tw), :]
    return cropped_data

def load_images(frame_dir, selected_frames):
    images = np.zeros((16, 224, 224, 3))
    orig_imgs = np.zeros_like(images)
    for i, frame_name in enumerate(selected_frames):
        im_name = os.path.join(frame_dir, frame_name)
        next_image = cv2.imread(im_name, cv2.IMREAD_COLOR)
        scaled_img = cv2.resize(next_image, (256, 256), interpolation=cv2.INTER_LINEAR)  # resize to 256x256
        cropped_img = center_crop(scaled_img)  # center crop 224x224
        final_img = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB)
        images[i] = final_img
        orig_imgs[i] = cropped_img

    torch_imgs = torch.from_numpy(images.transpose(3, 0, 1, 2))
    torch_imgs = torch_imgs.float() / 255.0
    mean_3d = [124 / 255, 117 / 255, 104 / 255]
    std_3d = [0.229, 0.224, 0.225]
    for t, m, s in zip(torch_imgs, mean_3d, std_3d):
        t.sub_(m).div_(s)
    return np.expand_dims(orig_imgs, 0), torch_imgs.unsqueeze(0)
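
# shape note (descriptive sketch, not in the original file): load_images returns
# the raw BGR frames as a (1, 16, 224, 224, 3) float array and the normalized
# clip as a (1, 3, 16, 224, 224) torch tensor.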


def parse_args():
    parser = argparse.ArgumentParser(description='mfnet-base-parser')
    parser.add_argument("num_classes", type=int)
    parser.add_argument("model_weights", type=str)
    parser.add_argument("frame_dir", type=str)
    parser.add_argument("label", type=int)
    parser.add_argument("--base_output_dir", type=str, default=r"visualisations")
    return parser.parse_args()

args = parse_args()

frame_names = os.listdir(args.frame_dir)
frame_indices = list(np.linspace(0, len(frame_names) - 1, num=16, dtype=int))
selected_frames = [frame_names[i] for i in frame_indices]

RGB_vid, vid = load_images(args.frame_dir, selected_frames)

# load network structure, load weights, send to gpu, set to evaluation mode
model_ft = MFNET_3D(args.num_classes)
model_ft = torch.nn.DataParallel(model_ft).cuda()
checkpoint = torch.load(args.model_weights, map_location={'cuda:1':'cuda:0'})
model_ft.load_state_dict(checkpoint['state_dict'])
model_ft.eval()

# get predictions, last convolution output and the weights of the prediction layer
predictions, layerout = model_ft(vid.cuda())
layerout = torch.tensor(layerout[0].numpy().transpose(1, 2, 3, 0))
pred_weights = model_ft.module.classifier.weight.data.detach().cpu().numpy().transpose()

pred = torch.argmax(predictions).item()

cam = np.zeros(dtype=np.float32, shape=layerout.shape[0:3])
for i, w in enumerate(pred_weights[:, args.label]):
    # Compute cam for every kernel
    cam += w * layerout[:, :, :, i]
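
# note: the loop above is equivalent to a single weighted sum over the channel
# axis; a vectorized sketch under the same shape assumptions:
#   cam = np.tensordot(layerout.numpy(), pred_weights[:, args.label], axes=([3], [0]))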

# Resize CAM to frame level
cam = zoom(cam, (2, 32, 32)) # output map is 8x7x7, so multiply to get to 16x224x224 (original image size)

# normalize
cam -= np.min(cam)
cam /= np.max(cam) - np.min(cam)

# make dirs and filenames
example_name = os.path.basename(args.frame_dir)
heatmap_dir = os.path.join(args.base_output_dir, example_name, str(args.label), "heatmap")
focusmap_dir = os.path.join(args.base_output_dir, example_name, str(args.label), "focusmap")
for d in [heatmap_dir, focusmap_dir]:
    if not os.path.exists(d):
        os.makedirs(d)

with open(os.path.join(args.base_output_dir, example_name, str(args.label), "info.txt"), "a") as f:
    f.write("Visualizing for class {}\n".format(args.label))
    f.write("Predicted class {}\n".format(pred))

# produce heatmap and focusmap for every frame and activation map
for i in range(0, cam.shape[0]):
    # Create colourmap
    heatmap = cv2.applyColorMap(np.uint8(255 * cam[i]), cv2.COLORMAP_JET)
    # Create focus map
    focusmap = np.uint8(255 * cam[i])
    focusmap = cv2.normalize(cam[i], dst=focusmap, alpha=20, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)

    # Create frame with heatmap
    heatframe = heatmap // 2 + RGB_vid[0][i] // 2
    cv2.imwrite(os.path.join(heatmap_dir, '{:03d}.png'.format(i)), heatframe)

    # Create frame with focus map in the alpha channel
    focusframe = RGB_vid[0][i]
    focusframe = cv2.cvtColor(np.uint8(focusframe), cv2.COLOR_BGR2BGRA)
    focusframe[:, :, 3] = focusmap
    cv2.imwrite(os.path.join(focusmap_dir, '{:03d}.png'.format(i)), focusframe)
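
# usage sketch (illustrative only; the class count, checkpoint path, frame
# directory and label below are placeholders, not part of the commit):
#   python heat_tubes_mfnet_pytorch.py 400 mfnet_ckpt.pth frames/clip01 42
# this would write heatmap/ and focusmap/ PNGs plus info.txt under
# visualisations/clip01/42/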

162 changes: 162 additions & 0 deletions mfnet_3d.py
@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-
"""
Original Author: Yunpeng Chen
https://github.com/cypw/PyTorch-MFNet/blob/master/network/mfnet_3d.py
"""

from collections import OrderedDict
import torch.nn as nn

class BN_AC_CONV3D(nn.Module):

    def __init__(self, num_in, num_filter,
                 kernel=(1,1,1), pad=(0,0,0), stride=(1,1,1), g=1, bias=False):
        super(BN_AC_CONV3D, self).__init__()
        self.bn = nn.BatchNorm3d(num_in)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_filter, kernel_size=kernel, padding=pad,
                              stride=stride, groups=g, bias=bias)

    def forward(self, x):
        h = self.relu(self.bn(x))
        h = self.conv(h)
        return h


class MF_UNIT(nn.Module):

    def __init__(self, num_in, num_mid, num_out, g=1, stride=(1,1,1), first_block=False, use_3d=True):
        super(MF_UNIT, self).__init__()
        num_ix = int(num_mid / 4)
        kt, pt = (3, 1) if use_3d else (1, 0)
        # prepare input
        self.conv_i1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_ix, kernel=(1,1,1), pad=(0,0,0))
        self.conv_i2 = BN_AC_CONV3D(num_in=num_ix, num_filter=num_in, kernel=(1,1,1), pad=(0,0,0))
        # main part
        self.conv_m1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_mid, kernel=(kt,3,3), pad=(pt,1,1), stride=stride, g=g)
        if first_block:
            self.conv_m2 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0))
        else:
            self.conv_m2 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_out, kernel=(1,3,3), pad=(0,1,1), g=g)
        # adapter
        if first_block:
            self.conv_w1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0), stride=stride)

    def forward(self, x):
        h = self.conv_i1(x)
        x_in = x + self.conv_i2(h)

        h = self.conv_m1(x_in)
        h = self.conv_m2(h)

        if hasattr(self, 'conv_w1'):
            x = self.conv_w1(x)

        return h + x


class MFNET_3D(nn.Module):

    def __init__(self, num_classes, dropout=None, pretrained=False, pretrained_model="", **kwargs):
        super(MFNET_3D, self).__init__()

        groups = 16
        k_sec = {2: 3,
                 3: 4,
                 4: 6,
                 5: 3}

        # conv1 - x224 (x16)
        conv1_num_out = 16
        self.conv1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv3d(3, conv1_num_out, kernel_size=(3,5,5), padding=(1,2,2), stride=(1,2,2), bias=False)),
            ('bn', nn.BatchNorm3d(conv1_num_out)),
            ('relu', nn.ReLU(inplace=True))
        ]))
        self.maxpool = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))

        # conv2 - x56 (x8)
        num_mid = 96
        conv2_num_out = 96
        self.conv2 = nn.Sequential(OrderedDict([
            ("B%02d" % i, MF_UNIT(num_in=conv1_num_out if i == 1 else conv2_num_out,
                                  num_mid=num_mid,
                                  num_out=conv2_num_out,
                                  stride=(2,1,1) if i == 1 else (1,1,1),
                                  g=groups,
                                  first_block=(i == 1))) for i in range(1, k_sec[2] + 1)
        ]))

        # conv3 - x28 (x8)
        num_mid *= 2
        conv3_num_out = 2 * conv2_num_out
        self.conv3 = nn.Sequential(OrderedDict([
            ("B%02d" % i, MF_UNIT(num_in=conv2_num_out if i == 1 else conv3_num_out,
                                  num_mid=num_mid,
                                  num_out=conv3_num_out,
                                  stride=(1,2,2) if i == 1 else (1,1,1),
                                  g=groups,
                                  first_block=(i == 1))) for i in range(1, k_sec[3] + 1)
        ]))

        # conv4 - x14 (x8)
        num_mid *= 2
        conv4_num_out = 2 * conv3_num_out
        self.conv4 = nn.Sequential(OrderedDict([
            ("B%02d" % i, MF_UNIT(num_in=conv3_num_out if i == 1 else conv4_num_out,
                                  num_mid=num_mid,
                                  num_out=conv4_num_out,
                                  stride=(1,2,2) if i == 1 else (1,1,1),
                                  g=groups,
                                  first_block=(i == 1))) for i in range(1, k_sec[4] + 1)
        ]))

        # conv5 - x7 (x8)
        num_mid *= 2
        conv5_num_out = 2 * conv4_num_out
        self.conv5 = nn.Sequential(OrderedDict([
            ("B%02d" % i, MF_UNIT(num_in=conv4_num_out if i == 1 else conv5_num_out,
                                  num_mid=num_mid,
                                  num_out=conv5_num_out,
                                  stride=(1,2,2) if i == 1 else (1,1,1),
                                  g=groups,
                                  first_block=(i == 1))) for i in range(1, k_sec[5] + 1)
        ]))

        # final
        self.tail = nn.Sequential(OrderedDict([
            ('bn', nn.BatchNorm3d(conv5_num_out)),
            ('relu', nn.ReLU(inplace=True))
        ]))

        if dropout:
            self.globalpool = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool3d(kernel_size=(8,7,7), stride=(1,1,1))),
                ('dropout', nn.Dropout(p=dropout)),
            ]))
        else:
            self.globalpool = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool3d(kernel_size=(8,7,7), stride=(1,1,1))),
                # ('dropout', nn.Dropout(p=0.5)),  # only for fine-tuning
            ]))
        self.classifier = nn.Linear(conv5_num_out, num_classes)

    def forward(self, x):
        assert x.shape[2] == 16

        h = self.conv1(x)    # x224 -> x112
        h = self.maxpool(h)  # x112 -> x56

        h = self.conv2(h)    # x56 -> x56
        h = self.conv3(h)    # x56 -> x28
        h = self.conv4(h)    # x28 -> x14
        h = self.conv5(h)    # x14 -> x7
        h = self.tail(h)
        layerout = h.detach().cpu()
        h = self.globalpool(h)

        h = h.view(h.shape[0], -1)
        h = self.classifier(h)

        return h, layerout
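
# smoke-test sketch (not part of the original MFNet file; the shapes follow from
# the strides above: conv5/tail yields 768 channels at 8x7x7 for a 16x224x224 clip):
if __name__ == '__main__':
    import torch
    model = MFNET_3D(num_classes=400)        # 400 classes is an arbitrary example
    model.eval()                             # use running BN stats for the check
    clip = torch.randn(1, 3, 16, 224, 224)   # (batch, channels, frames, H, W)
    logits, layerout = model(clip)
    print(logits.shape)                      # torch.Size([1, 400])
    print(layerout.shape)                    # torch.Size([1, 768, 8, 7, 7])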
