Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tf layer multi gpu #8

Open
wants to merge 5 commits into
base: tf_layer
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 100 additions & 30 deletions continued_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,64 @@
import numpy as np
import torch.nn as nn

class Model(nn.Module):
    """Bundle the whole text-to-shape pipeline (frozen CLIP, autoencoder,
    latent flow, differentiable renderer) into one nn.Module so the full
    forward pass can be wrapped in nn.DataParallel for multi-GPU training.

    forward(text_features) returns the scalar CLIP loss for the queries in
    ``query_array``.
    """

    def __init__(self, args, clip_model, autoencoder, latent_flow_model, renderer, resizer, query_array):
        super(Model, self).__init__()
        self.autoencoder = autoencoder
        self.latent_flow_model = latent_flow_model

        self.clip_model = clip_model
        # CLIP is used only as a frozen feature extractor: freeze it so the
        # optimizer built from self.parameters() never updates it.
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.renderer = renderer
        self.resizer = resizer
        self.query_array = query_array
        self.args = args

    def get_shapes(self):
        """Decode one voxel grid of occupancy values per query string.

        Returns the raw output of ``autoencoder.decode`` — presumably shaped
        (len(query_array), num_voxels, num_voxels, num_voxels); TODO confirm
        against the autoencoder implementation.
        """
        self.autoencoder.train()
        self.latent_flow_model.eval()  # has to be in .eval() mode for the sampling to work (which is bad but whatever)

        voxel_size = self.args.num_voxels
        batch_size = len(self.query_array)

        # Regular 3D grid of query points spanning [-0.5, 0.5]^3, shared
        # (expanded, not copied) across the batch.
        shape = (voxel_size, voxel_size, voxel_size)
        p = visualization.make_3d_grid([-0.5] * 3, [+0.5] * 3, shape).type(torch.FloatTensor).to(self.args.device)
        query_points = p.expand(batch_size, *p.size())

        # NOTE(review): torch.Tensor(...) allocates *uninitialized* memory;
        # the function-style version of this code (gen_shapes) draws the
        # noise with .normal_(). Confirm whether autoencoder.decode actually
        # consumes this noise and whether it should be sampled.
        noise = torch.Tensor(batch_size, self.args.latent_dim).to(self.args.device)
        # NOTE(review): get_condition_embeddings must exist at module level;
        # elsewhere this file defines get_text_embeddings — verify the name.
        decoder_embs = get_condition_embeddings(self.args, self.clip_model, self.query_array)

        out_3d = self.autoencoder.decode(noise, decoder_embs, query_points)
        return out_3d

    def clip_loss(self, text_features):
        """Render the decoded shapes and return the negative mean cosine
        similarity between CLIP image embeddings and ``text_features``."""
        out_3d = self.get_shapes()
        # Soft, differentiable occupancy: steep sigmoid centered on the
        # binarization threshold.
        out_3d_soft = torch.sigmoid(self.args.beta * (out_3d - self.args.threshold))

        ims = self.renderer.render(out_3d_soft).double()
        ims = self.resizer(ims)

        img_embs = self.clip_model.encode_image(ims.type(self.args.visual_model_type))
        if self.args.renderer == 'ea':
            # baseline renderer gives 3 dimensions (3 views per shape), so
            # repeat each text embedding 3x to line up with the images
            text_features = text_features.unsqueeze(1).expand(-1, 3, -1).reshape(-1, 512)

        # BUG FIX: original referenced undefined name `im_embs` (NameError);
        # the image embeddings are stored in `img_embs` above.
        losses = -1 * torch.cosine_similarity(text_features, img_embs)
        loss = losses.mean()

        return loss

    def forward(self, text_features):
        """nn.DataParallel entry point: compute the CLIP loss."""
        return self.clip_loss(text_features)

def get_type(visual_model):
    """Return the parameter dtype of a CLIP visual backbone.

    The dtype is inferred from the first convolution's weight tensor.
    """
    first_conv_weight = visual_model.conv1.weight
    return first_conv_weight.dtype

Expand Down Expand Up @@ -75,12 +133,7 @@ def get_clip_model(args):
input_resolution = clip_model.visual.input_resolution
#train_cond_embs_length = clip_model.train_cond_embs_length
vocab_size = clip_model.vocab_size
#cond_emb_dim = clip_model.embed_dim
#print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in clip_model.parameters()]):,}")
print("cond_emb_dim:", cond_emb_dim)
print("Input resolution:", input_resolution)
#print("train_cond_embs length:", train_cond_embs_length)
print("Vocab size:", vocab_size)

args.n_px = input_resolution
args.cond_emb_dim = cond_emb_dim
return args, clip_model
Expand All @@ -98,7 +151,8 @@ def get_text_embeddings(args,clip_model,query_array):

return text_features

def gen_shapes(query_array,args,autoencoder,latent_flow_model,text_features):
def gen_shapes(query_array,args,visual_model,autoencoder,latent_flow_model,text_features):
# clip_model.eval()
autoencoder.train()
latent_flow_model.eval() # has to be in .eval() mode for the sampling to work (which is bad but whatever)

Expand All @@ -112,14 +166,16 @@ def gen_shapes(query_array,args,autoencoder,latent_flow_model,text_features):
noise = torch.Tensor(batch_size, args.emb_dims).normal_().to(args.device)
decoder_embs = latent_flow_model.sample(batch_size, noise=noise, cond_inputs=text_features)

out_3d = autoencoder.decoding(decoder_embs, query_points).view(batch_size, voxel_size, voxel_size, voxel_size).to(args.device)
# out_3d = autoencoder.decoding(decoder_embs, query_points).view(batch_size, voxel_size, voxel_size, voxel_size).to(args.device)
out_3d = autoencoder(decoder_embs, query_points).view(batch_size, voxel_size, voxel_size, voxel_size).to(args.device)
return out_3d

def do_eval(renderer,query_array,args,visual_model,autoencoder,latent_flow_model,resizer,iter,text_features=None):
out_3d = gen_shapes(query_array,args,visual_model,autoencoder,latent_flow_model,text_features)
#save out_3d to numpy file
# with open(f'out_3d/{args.learning_rate}_{args.query_array}/out_3d_{iter}.npy', 'wb') as f:
# np.save(f, out_3d.cpu().detach().numpy())
os.makedirs(f'out_3d/{args.learning_rate}_{args.query_array}', exist_ok=True)
with open(f'out_3d/{args.learning_rate}_{args.query_array}/out_3d_{iter}.npy', 'wb') as f:
np.save(f, out_3d.cpu().detach().numpy())

out_3d_hard = out_3d.detach() > args.threshold
rgbs_hard = renderer.render(out_3d_hard.float()).double().to(args.device)
Expand Down Expand Up @@ -165,6 +221,8 @@ def evaluate_true_voxel(out_3d,args,visual_model,text_features,i):
# get CLIP embedding
# voxel_image_embedding = clip_model.encode_image(voxel_tensor.to(args.device))
voxel_image_embedding = visual_model(voxel_tensor.to(args.device).type(visual_model_type))
print("voxel_image_embedding",voxel_image_embedding.shape)
print("text_features",text_features.shape)
voxel_similarity = torch.cosine_similarity(text_features, voxel_image_embedding).mean()
return voxel_similarity

Expand Down Expand Up @@ -200,10 +258,13 @@ def get_local_parser(mode="args"):
else:
return parser

def test_train(args,clip_model,autoencoder,latent_flow_model,renderer):
def test_train(args,clip_model,autoencoder,latent_flow_model,renderer):



resizer = T.Resize(224)
flow_optimizer=optim.Adam(latent_flow_model.parameters(), lr=args.learning_rate)
net_optimizer=optim.Adam(autoencoder.parameters(), lr=args.learning_rate)
# flow_optimizer=optim.Adam(latent_flow_model.parameters(), lr=args.learning_rate)
# net_optimizer=optim.Adam(autoencoder.parameters(), lr=args.learning_rate)

losses = []

Expand All @@ -218,30 +279,41 @@ def test_train(args,clip_model,autoencoder,latent_flow_model,renderer):
if not os.path.exists('queries/%s' % args.writer.log_dir[5:]):
os.makedirs('queries/%s' % args.writer.log_dir[5:])

#remove text components from clip and free up memory
visual_model = clip_model.visual
del clip_model
# model
net = Model(args,clip_model,autoencoder,latent_flow_model,renderer,resizer,query_array)

net_optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

#set gradient of clip model to false
for param in visual_model.parameters():
param.requires_grad = False
visual_model.eval()
torch.cuda.empty_cache()
net = nn.DataParallel(net)

# print("Done")

# #remove text components from clip and free up memory
# visual_model = clip_model.visual
# del clip_model

global visual_model_type
visual_model_type = get_type(visual_model)
visual_model = nn.DataParallel(visual_model)
# #set gradient of clip model to false
# for param in visual_model.parameters():
# param.requires_grad = False
# visual_model.eval()
# torch.cuda.empty_cache()

# global visual_model_type
# visual_model_type = get_type(visual_model)
# visual_model = nn.DataParallel(visual_model)

for iter in range(20000):
if not iter%300:
do_eval(renderer,query_array,args,visual_model,autoencoder,latent_flow_model,resizer,iter,text_features)
# if not iter%100:
# do_eval(renderer,query_array,args,visual_model,autoencoder,latent_flow_model,resizer,iter,text_features)


flow_optimizer.zero_grad()
# flow_optimizer.zero_grad()
net_optimizer.zero_grad()

loss = clip_loss(args,query_array,visual_model,autoencoder,latent_flow_model,renderer,resizer,iter,text_features)
# loss = clip_loss(args,query_array,visual_model,autoencoder,latent_flow_model,renderer,resizer,iter,text_features)

loss = net(text_features)

loss.backward()

losses.append(loss.item())
Expand Down Expand Up @@ -282,8 +354,6 @@ def main(args):

latent_flow_network = latent_flows.get_generator(args.emb_dims, args.cond_emb_dim, device, flow_type=args.flow_type, num_blocks=args.num_blocks, num_hidden=args.num_hidden)
if not args.uninitialized:
print(args.checkpoint_dir_prior)
print(args.checkpoint)
sys.stdout.flush()
checkpoint_nf_path = os.path.join(args.checkpoint_dir_prior, args.checkpoint_nf +".pt")
checkpoint = torch.load(checkpoint_nf_path, map_location=args.device)
Expand Down
10 changes: 5 additions & 5 deletions networks/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __init__(self, args):



def decoding(self, shape_embedding, points=None):
def forward(self, shape_embedding, points=None):
if self.output_type == "Pointcloud":
return self.decoder(shape_embedding)
else:
Expand All @@ -259,7 +259,7 @@ def reconstruction_loss(self, pred, gt):
return loss


def forward(self, data_input, query_points=None):
shape_embs = self.encoder(data_input)
pred = self.decoding(shape_embs, points=query_points)
return pred, shape_embs
# def forward(self, data_input, query_points=None):
# shape_embs = self.encoder(data_input)
# pred = self.decoding(shape_embs, points=query_points)
# return pred, shape_embs
7 changes: 7 additions & 0 deletions scripts/debug.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Quick local (non-Slurm) debug run: single "fork" query, 1 rendered view.
cd /home/mp5847/src/general_clip_forge; export PYTHONPATH="$PYTHONPATH:$PWD";

# NOTE(review): "01e-05" is parsed by argparse's float() as 1e-05.
python continued_training.py --num_voxels 128 --learning_rate 01e-05 \
--gpu 0 --beta 150.0 --query_array "fork" --num_views 1 \
--checkpoint_dir_base /scratch/mp5847/general_clip_forge/exps/models/autoencoder \
--checkpoint best_iou --checkpoint_dir_prior \
/scratch/mp5847/general_clip_forge/exps/models/prior --checkpoint_nf best --renderer nvr+
26 changes: 26 additions & 0 deletions scripts/prompt/run_fork_prompt0.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
# Slurm job: continued training with the plain prompt "fork" (prompt 0 of
# the fork prompt-engineering sweep). 2x A100, 10 rendered views per shape.
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=6
#SBATCH --gres=gpu:a100:2
#SBATCH --time=47:59:00
#SBATCH --mem=64GB
#SBATCH --job-name=clip_forge_prompt0
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mp5847@nyu.edu
#SBATCH --output=clip_forge_prompt0_%j.out

module purge

# Run inside the lab's Singularity image; the overlay provides the conda env
# sourced by /ext3/env.sh.
singularity exec --nv \
	    --overlay /scratch/km3888/singularity_forge/3d.ext3:ro \
	    /scratch/work/public/singularity/cuda11.2.2-cudnn8-devel-ubuntu20.04.sif \
	    /bin/bash -c 'source /ext3/env.sh; cd /home/mp5847/src/general_clip_forge; export PYTHONPATH="$PYTHONPATH:$PWD"; \
 python continued_training.py --num_voxels 128 --learning_rate 01e-05 \
 --gpu 0 --beta 150.0 --query_array "fork" --num_views 10 \
 --checkpoint_dir_base /scratch/mp5847/general_clip_forge/exps/models/autoencoder \
 --checkpoint best_iou --checkpoint_dir_prior \
 /scratch/mp5847/general_clip_forge/exps/models/prior --checkpoint_nf best --renderer nvr+'




26 changes: 26 additions & 0 deletions scripts/prompt/run_fork_prompt1.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
# Slurm job: continued training with prompt 1, "A picture of a fork",
# of the fork prompt-engineering sweep. 2x A100, 10 views per shape.
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=6
#SBATCH --gres=gpu:a100:2
#SBATCH --time=47:59:00
#SBATCH --mem=64GB
#SBATCH --job-name=clip_forge_prompt1
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mp5847@nyu.edu
#SBATCH --output=clip_forge_prompt1_%j.out

module purge

# Run inside the lab's Singularity image; the overlay provides the conda env
# sourced by /ext3/env.sh.
singularity exec --nv \
	    --overlay /scratch/km3888/singularity_forge/3d.ext3:ro \
	    /scratch/work/public/singularity/cuda11.2.2-cudnn8-devel-ubuntu20.04.sif \
	    /bin/bash -c 'source /ext3/env.sh; cd /home/mp5847/src/general_clip_forge; export PYTHONPATH="$PYTHONPATH:$PWD"; \
 python continued_training.py --num_voxels 128 --learning_rate 01e-05 \
 --gpu 0 --beta 150.0 --query_array "A picture of a fork" --num_views 10 \
 --checkpoint_dir_base /scratch/mp5847/general_clip_forge/exps/models/autoencoder \
 --checkpoint best_iou --checkpoint_dir_prior \
 /scratch/mp5847/general_clip_forge/exps/models/prior --checkpoint_nf best --renderer nvr+'




26 changes: 26 additions & 0 deletions scripts/prompt/run_fork_prompt2.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
# Slurm job: continued training with prompt 2, "A rendering of a fork".
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=6
#SBATCH --gres=gpu:a100:2
#SBATCH --time=47:59:00
#SBATCH --mem=64GB
#SBATCH --job-name=clip_forge_prompt2
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mp5847@nyu.edu
# BUG FIX: --output previously said clip_forge_prompt1_%j.out (copy-paste
# from the prompt1 script), which mislabels this job's logs as prompt1's.
#SBATCH --output=clip_forge_prompt2_%j.out

module purge

singularity exec --nv \
	    --overlay /scratch/km3888/singularity_forge/3d.ext3:ro \
	    /scratch/work/public/singularity/cuda11.2.2-cudnn8-devel-ubuntu20.04.sif \
	    /bin/bash -c 'source /ext3/env.sh; cd /home/mp5847/src/general_clip_forge; export PYTHONPATH="$PYTHONPATH:$PWD"; \
 python continued_training.py --num_voxels 128 --learning_rate 01e-05 \
 --gpu 0 --beta 150.0 --query_array "A rendering of a fork" --num_views 10 \
 --checkpoint_dir_base /scratch/mp5847/general_clip_forge/exps/models/autoencoder \
 --checkpoint best_iou --checkpoint_dir_prior \
 /scratch/mp5847/general_clip_forge/exps/models/prior --checkpoint_nf best --renderer nvr+'




26 changes: 26 additions & 0 deletions scripts/prompt/run_fork_prompt3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
# Slurm job: continued training with prompt 3,
# "A fork on a blue and white background".
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=6
#SBATCH --gres=gpu:a100:2
#SBATCH --time=47:59:00
#SBATCH --mem=64GB
#SBATCH --job-name=clip_forge_prompt3
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mp5847@nyu.edu
# BUG FIX: --output previously said clip_forge_prompt1_%j.out (copy-paste
# from the prompt1 script), which mislabels this job's logs as prompt1's.
#SBATCH --output=clip_forge_prompt3_%j.out

module purge

singularity exec --nv \
	    --overlay /scratch/km3888/singularity_forge/3d.ext3:ro \
	    /scratch/work/public/singularity/cuda11.2.2-cudnn8-devel-ubuntu20.04.sif \
	    /bin/bash -c 'source /ext3/env.sh; cd /home/mp5847/src/general_clip_forge; export PYTHONPATH="$PYTHONPATH:$PWD"; \
 python continued_training.py --num_voxels 128 --learning_rate 01e-05 \
 --gpu 0 --beta 150.0 --query_array "A fork on a blue and white background" --num_views 10 \
 --checkpoint_dir_base /scratch/mp5847/general_clip_forge/exps/models/autoencoder \
 --checkpoint best_iou --checkpoint_dir_prior \
 /scratch/mp5847/general_clip_forge/exps/models/prior --checkpoint_nf best --renderer nvr+'




26 changes: 26 additions & 0 deletions scripts/prompt/run_fork_prompt4.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
# Slurm job: continued training with prompt 4 (descriptive fork sentence).
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=6
#SBATCH --gres=gpu:a100:2
#SBATCH --time=47:59:00
#SBATCH --mem=64GB
#SBATCH --job-name=clip_forge_prompt4
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mp5847@nyu.edu
# BUG FIX: --output previously said clip_forge_prompt1_%j.out (copy-paste
# from the prompt1 script), which mislabels this job's logs as prompt1's.
#SBATCH --output=clip_forge_prompt4_%j.out

module purge

singularity exec --nv \
	    --overlay /scratch/km3888/singularity_forge/3d.ext3:ro \
	    /scratch/work/public/singularity/cuda11.2.2-cudnn8-devel-ubuntu20.04.sif \
	    /bin/bash -c 'source /ext3/env.sh; cd /home/mp5847/src/general_clip_forge; export PYTHONPATH="$PYTHONPATH:$PWD"; \
 python continued_training.py --num_voxels 128 --learning_rate 01e-05 \
 --gpu 0 --beta 150.0 --query_array "A fork typically has four tines, is made of metal, and is silver in color." --num_views 10 \
 --checkpoint_dir_base /scratch/mp5847/general_clip_forge/exps/models/autoencoder \
 --checkpoint best_iou --checkpoint_dir_prior \
 /scratch/mp5847/general_clip_forge/exps/models/prior --checkpoint_nf best --renderer nvr+'




26 changes: 26 additions & 0 deletions scripts/prompt/run_fork_prompt5.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
# Slurm job: continued training with prompt 5 (long descriptive fork sentence).
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=6
#SBATCH --gres=gpu:a100:2
#SBATCH --time=47:59:00
#SBATCH --mem=64GB
#SBATCH --job-name=clip_forge_prompt5
#SBATCH --mail-type=ALL
#SBATCH --mail-user=mp5847@nyu.edu
# BUG FIX: --output previously said clip_forge_prompt1_%j.out (copy-paste
# from the prompt1 script), which mislabels this job's logs as prompt1's.
#SBATCH --output=clip_forge_prompt5_%j.out

module purge

singularity exec --nv \
	    --overlay /scratch/km3888/singularity_forge/3d.ext3:ro \
	    /scratch/work/public/singularity/cuda11.2.2-cudnn8-devel-ubuntu20.04.sif \
	    /bin/bash -c 'source /ext3/env.sh; cd /home/mp5847/src/general_clip_forge; export PYTHONPATH="$PYTHONPATH:$PWD"; \
 python continued_training.py --num_voxels 128 --learning_rate 01e-05 \
 --gpu 0 --beta 150.0 --query_array "A fork is a utensil with a handle and usually four tines or prongs at the other end, designed in various materials and with decorative patterns." --num_views 10 \
 --checkpoint_dir_base /scratch/mp5847/general_clip_forge/exps/models/autoencoder \
 --checkpoint best_iou --checkpoint_dir_prior \
 /scratch/mp5847/general_clip_forge/exps/models/prior --checkpoint_nf best --renderer nvr+'




Loading