Permalink
Browse files

allreduce working in Windows now, with help from Robert

  • Loading branch information...
1 parent ca2bee8 commit 6ec912d53b83823c7b5dc3e3824db87d66ee383d U-NORTHAMERICA\jcl committed Sep 27, 2012
Showing with 67 additions and 47 deletions.
  1. +30 −18 cluster/spanning_tree.cc
  2. +37 −29 vowpalwabbit/allreduce.cc
View
@@ -12,13 +12,14 @@ This creates a binary tree topology over a set of n nodes that connect.
#include <Windows.h>
#include <io.h>
+#define SHUT_RDWR SD_BOTH
+
typedef unsigned int uint32_t;
typedef unsigned short uint16_t;
typedef int socklen_t;
int daemon(int a, int b)
{
- exit(0);
return 0;
}
int getpid()
@@ -94,11 +95,11 @@ int build_tree(int* parent, uint16_t* kid_count, int source_count, int offset)
return oroot;
}
-void fail_write(int fd, const void* buf, size_t count)
+void fail_send(int fd, const void* buf, size_t count)
{
- if (write(fd,buf,count)==-1)
+ if (send(fd,(char*)buf,count,0)==-1)
{
- cerr << "write failed!" << endl;
+ cerr << "send failed!" << endl;
exit(1);
}
}
@@ -110,9 +111,17 @@ int main(int argc, char* argv[]) {
exit(0);
}
+#ifdef _WIN32
+ WSAData wsaData;
+ WSAStartup(MAKEWORD(2,2), &wsaData);
+ int lastError = WSAGetLastError();
+#endif
+
int sock = socket(PF_INET, SOCK_STREAM, 0);
if (sock < 0) {
- cerr << "can't open socket!" << endl;
+ lastError = WSAGetLastError();
+
+ cerr << "can't open socket! (" << lastError << ")" << endl;
exit(1);
}
@@ -152,7 +161,6 @@ int main(int argc, char* argv[]) {
}
map<int, partial> partial_nodesets;
-
while(true) {
listen(sock, 1024);
@@ -166,19 +174,19 @@ int main(int argc, char* argv[]) {
}
size_t nonce = 0;
- if (read(f, &nonce, sizeof(nonce)) != sizeof(nonce))
+ if (recv(f, (char*)&nonce, sizeof(nonce), 0) != sizeof(nonce))
{
cerr << "nonce read failed, exiting" << endl;
exit(1);
}
size_t total = 0;
- if (read(f, &total, sizeof(total)) != sizeof(total))
+ if (recv(f, (char*)&total, sizeof(total), 0) != sizeof(total))
{
cerr << "total node count read failed, exiting" << endl;
exit(1);
}
size_t id = 0;
- if (read(f, &id, sizeof(id)) != sizeof(id))
+ if (recv(f, (char*)&id, sizeof(id), 0) != sizeof(id))
{
cerr << "node id read failed, exiting" << endl;
exit(1);
@@ -206,7 +214,7 @@ int main(int argc, char* argv[]) {
if (ok && partial_nodeset.nodes[id].client_ip != (uint32_t)-1)
ok = false;
- fail_write(f,&ok, sizeof(ok));
+ fail_send(f,&ok, sizeof(ok));
if (ok)
{
@@ -230,34 +238,38 @@ int main(int argc, char* argv[]) {
for (size_t i = 0; i < total; i++)
{
- fail_write(partial_nodeset.nodes[i].socket, &kid_count[i], sizeof(kid_count[i]));
+ fail_send(partial_nodeset.nodes[i].socket, &kid_count[i], sizeof(kid_count[i]));
}
uint16_t* client_ports=(uint16_t*)calloc(total,sizeof(uint16_t));
for(size_t i = 0;i < total;i++) {
int done = 0;
- if(read(partial_nodeset.nodes[i].socket, &(client_ports[i]), sizeof(client_ports[i])) < (int) sizeof(client_ports[i]))
+ if(recv(partial_nodeset.nodes[i].socket, (char*)&(client_ports[i]), sizeof(client_ports[i]), 0) < (int) sizeof(client_ports[i]))
cerr<<" Port read failed for node "<<i<<" read "<<done<<endl;
}// all clients have bound to their ports.
for (size_t i = 0; i < total; i++)
{
if (parent[i] >= 0)
{
- fail_write(partial_nodeset.nodes[i].socket, &partial_nodeset.nodes[parent[i]].client_ip, sizeof(partial_nodeset.nodes[parent[i]].client_ip));
- fail_write(partial_nodeset.nodes[i].socket, &client_ports[parent[i]], sizeof(client_ports[parent[i]]));
- }
+ fail_send(partial_nodeset.nodes[i].socket, &partial_nodeset.nodes[parent[i]].client_ip, sizeof(partial_nodeset.nodes[parent[i]].client_ip));
+ fail_send(partial_nodeset.nodes[i].socket, &client_ports[parent[i]], sizeof(client_ports[parent[i]]));
+ }
else
{
int bogus = -1;
uint32_t bogus2 = -1;
- fail_write(partial_nodeset.nodes[i].socket, &bogus2, sizeof(bogus2));
- fail_write(partial_nodeset.nodes[i].socket, &bogus, sizeof(bogus));
+ fail_send(partial_nodeset.nodes[i].socket, &bogus2, sizeof(bogus2));
+ fail_send(partial_nodeset.nodes[i].socket, &bogus, sizeof(bogus));
}
- close(partial_nodeset.nodes[i].socket);
+ shutdown(partial_nodeset.nodes[i].socket, SHUT_RDWR);
}
free (partial_nodeset.nodes);
}
}
+
+#ifdef _WIN32
+ WSACleanup();
+#endif
}
View
@@ -16,6 +16,7 @@ Alekh Agarwal and John Langford, with help Olivier Chapelle.
typedef unsigned int uint32_t;
typedef unsigned short uint16_t;
typedef int socklen_t;
+#define SHUT_RDWR SD_BOTH
#else
#include <sys/socket.h>
#include <sys/socket.h>
@@ -119,8 +120,14 @@ int getsock()
void all_reduce_init(string master_location, size_t unique_id, size_t total, size_t node)
{
- struct hostent* master = gethostbyname(master_location.c_str());
-
+#ifdef _WIN32
+ WSAData wsaData;
+ WSAStartup(MAKEWORD(2,2), &wsaData);
+ int lastError = WSAGetLastError();
+#endif
+
+ struct hostent* master = gethostbyname(master_location.c_str());
+
if (master == NULL) {
cerr << "can't resolve hostname: " << master_location << endl;
exit(1);
@@ -130,16 +137,20 @@ void all_reduce_init(string master_location, size_t unique_id, size_t total, siz
uint32_t master_ip = * ((uint32_t*)master->h_addr);
int port = 26543;
- int master_sock = sock_connect(master_ip, htons(port));
-
- if(write(master_sock, &unique_id, sizeof(unique_id)) < (int)sizeof(unique_id))
- cerr << "write failed!" << endl;
- if(write(master_sock, &total, sizeof(total)) < (int)sizeof(total))
+#ifdef _WIN32
+ SOCKET master_sock;
+#else
+ int master_sock;
+#endif
+ master_sock = sock_connect(master_ip, htons(port));
+ if(send(master_sock, (const char*)&unique_id, sizeof(unique_id), 0) < (int)sizeof(unique_id))
+ cerr << "write failed!" << WSAGetLastError() << endl;
+ if(send(master_sock, (const char*)&total, sizeof(total), 0) < (int)sizeof(total))
cerr << "write failed!" << endl;
- if(write(master_sock, &node, sizeof(node)) < (int)sizeof(node))
+ if(send(master_sock, (char*)&node, sizeof(node), 0) < (int)sizeof(node))
cerr << "write failed!" << endl;
int ok;
- if (read(master_sock, &ok, sizeof(ok)) < (int)sizeof(ok))
+ if (recv(master_sock, (char*)&ok, sizeof(ok), 0) < (int)sizeof(ok))
cerr << "read 1 failed!" << endl;
if (!ok) {
cerr << "mapper already connected" << endl;
@@ -150,7 +161,7 @@ void all_reduce_init(string master_location, size_t unique_id, size_t total, siz
uint16_t parent_port;
uint32_t parent_ip;
- if(read(master_sock, &kid_count, sizeof(kid_count)) < (int)sizeof(kid_count))
+ if(recv(master_sock, (char*)&kid_count, sizeof(kid_count), 0) < (int)sizeof(kid_count))
cerr << "read 2 failed!" << endl;
int sock = -1;
@@ -165,7 +176,7 @@ void all_reduce_init(string master_location, size_t unique_id, size_t total, siz
bool listening = false;
while(!listening)
{
- if (bind(sock,(sockaddr*)&address, sizeof(address)) < 0)
+ if (bind(sock,(sockaddr*)&address, sizeof(address)) < 0)
if (errno == EADDRINUSE)
{
netport = htons(ntohs(netport)+1);
@@ -180,7 +191,7 @@ void all_reduce_init(string master_location, size_t unique_id, size_t total, siz
if (listen(sock, kid_count) < 0)
{
perror("listen failed! ");
- close(sock);
+ shutdown(sock, SHUT_RDWR);
sock = getsock();
}
else
@@ -190,16 +201,15 @@ void all_reduce_init(string master_location, size_t unique_id, size_t total, siz
}
}
- if(write(master_sock, &netport, sizeof(netport)) < (int)sizeof(netport))
+ if(send(master_sock, (const char*)&netport, sizeof(netport), 0) < (int)sizeof(netport))
cerr << "write failed!" << endl;
- if(read(master_sock, &parent_ip, sizeof(parent_ip)) < (int)sizeof(parent_ip))
+ if(recv(master_sock, (char*)&parent_ip, sizeof(parent_ip), 0) < (int)sizeof(parent_ip))
cerr << "read 3 failed!" << endl;
- if(read(master_sock, &parent_port, sizeof(parent_port)) < (int)sizeof(parent_port))
+ if(recv(master_sock, (char*)&parent_port, sizeof(parent_port), 0) < (int)sizeof(parent_port))
cerr << "read 4 failed!" << endl;
-
- close(master_sock);
+ shutdown(master_sock, SD_BOTH);
//int parent_sock;
if(parent_ip != (uint32_t)-1)
@@ -208,7 +218,6 @@ void all_reduce_init(string master_location, size_t unique_id, size_t total, siz
socks.parent = -1;
socks.children[0] = -1; socks.children[1] = -1;
-
for (int i = 0; i < kid_count; i++)
{
sockaddr_in child_address;
@@ -223,7 +232,7 @@ void all_reduce_init(string master_location, size_t unique_id, size_t total, siz
}
if (kid_count > 0)
- close(sock);
+ shutdown(sock, SHUT_RDWR);
}
void addbufs(float* buf1, float* buf2, int n) {
@@ -246,7 +255,7 @@ void pass_up(char* buffer, int left_read_pos, int right_read_pos, int& parent_se
if(my_bufsize > 0) {
//going to pass up this chunk of data to the parent
- int write_size = write(parent_sock, buffer+parent_sent_pos, my_bufsize);
+ int write_size = send(parent_sock, buffer+parent_sent_pos, my_bufsize, 0);
if(write_size < my_bufsize)
cerr<<"Write to parent failed "<<my_bufsize<<" "<<write_size<<" "<<parent_sent_pos<<" "<<left_read_pos<<" "<<right_read_pos<<endl ;
parent_sent_pos += my_bufsize;
@@ -262,9 +271,9 @@ void pass_down(char* buffer, int parent_read_pos, int&children_sent_pos, int* ch
if(my_bufsize > 0) {
//going to pass up this chunk of data to the children
- if(child_sockets[0] != -1 && write(child_sockets[0], buffer+children_sent_pos, my_bufsize) < my_bufsize)
+ if(child_sockets[0] != -1 && send(child_sockets[0], buffer+children_sent_pos, my_bufsize, 0) < my_bufsize)
cerr<<"Write to left child failed\n";
- if(child_sockets[1] != -1 && write(child_sockets[1], buffer+children_sent_pos, my_bufsize) < my_bufsize)
+ if(child_sockets[1] != -1 && send(child_sockets[1], buffer+children_sent_pos, my_bufsize, 0) < my_bufsize)
cerr<<"Write to right child failed\n";
children_sent_pos += my_bufsize;
@@ -326,7 +335,7 @@ void reduce(char* buffer, int n, int parent_sock, int* child_sockets) {
//float read_buf[buf_size];
size_t count = min(buf_size,n - child_read_pos[i]);
- int read_size = read(child_sockets[i], child_read_buf[i] + child_unprocessed[i], count);
+ int read_size = recv(child_sockets[i], child_read_buf[i] + child_unprocessed[i], count, 0);
if(read_size == -1) {
cerr <<" Read from child failed\n";
perror(NULL);
@@ -398,7 +407,7 @@ void broadcast(char* buffer, int n, int parent_sock, int* child_sockets) {
exit(1);
}
size_t count = min(buf_size,n-parent_read_pos);
- int read_size = read(parent_sock, buffer + parent_read_pos, count);
+ int read_size = recv(parent_sock, buffer + parent_read_pos, count, 0);
if(read_size == -1) {
cerr <<" Read from parent failed\n";
perror(NULL);
@@ -412,20 +421,19 @@ void all_reduce(float* buffer, int n, string master_location, size_t unique_id,
{
if(master_location != current_master)
all_reduce_init(master_location, unique_id, total, node);
-
reduce((char*)buffer, n*sizeof(float), socks.parent, socks.children);
broadcast((char*)buffer, n*sizeof(float), socks.parent, socks.children);
}
node_socks::~node_socks()
{
- if(current_master != "") {
+ if(current_master != "") {
if(this->parent != -1)
- close(this->parent);
+ shutdown(this->parent, SHUT_RDWR);
if(this->children[0] != -1)
- close(this->children[0]);
+ shutdown(this->children[0], SHUT_RDWR);
if(this->children[1] != -1)
- close(this->children[1]);
+ shutdown(this->children[1], SHUT_RDWR);
}
}

0 comments on commit 6ec912d

Please sign in to comment.